In [2]:
from typing import Optional
import os
import numpy as np
import torch
from torchvision import transforms
import warnings
from decord import VideoReader, cpu
from torch.utils.data import Dataset

from augmentations import video_transforms as video_transforms
from augmentations import volume_transforms as volume_transforms
from augmentations.random_erasing import RandomErasing

class SSV2(Dataset):
    """
    Custom dataloader for SSV2. Given the path containing the SSV2 Dataset videos, annotations, and the mode (train, test, val), 
    return N uniformly sampled frames of each video. Adapted from: https://github.com/MCG-NJU/VideoMAE/blob/main/ssv2.py
    """

    def __init__(self, dataset_path, label_path, num_sample, 
                mode='train', crop_size=224, short_side_size=240,
                test_num_segment=5, test_num_crop=3, num_aug_sample=1, args=None, label_file=None):
        """
        Params:
        dataset_path: str
            Path to the folder containing SSv2 videos
        label_path: str
            Path to the folder containing test_videofolder.txt, train_videofolder.txt, val_videofolder.txt files
        num_sample: int
            Number of frames to sample 
        mode: str
            One of "train", "test", "val"
        crop_size: int
            Size to crop each frame
        short_side_size: int
            Size of the shorter side of the videos
        test_num_segment: int
            Number of temporal views to sample at test time
        test_num_crop: int
            Number of spatial views to sample at test time - uniformly sampled along the longer dimension
            of the input video
        num_aug_sample: int
            Number of times to augment the input video (repeated augmentation)
        args: argparse.Namespace
            Arguments to pass to the augmentations
        label_file: Optional[str]
            Prefix for the file containing the labels - if None, assumed to be equal to mode. Used for the case where we want to
            use validation/test view logic on the training set
        """
        paths = []
        labels = []
        num_frames = []

        if label_file is None:
            label_file = mode
        with open(os.path.join(label_path, f"{label_file}_videofolder.txt"), "r") as f:
            for line in f:
                vals = line.split()
                paths.append(f"{os.path.join(dataset_path, vals[0])}.webm")
                num_frames.append(int(vals[1]))
                labels.append(int(vals[2]))

        self.paths = paths
        self.labels = labels
        self.num_frames = num_frames
        self.short_side_size = short_side_size
        self.crop_size = crop_size
        self.mode = mode
        self.num_sample = num_sample
        self.test_num_segment = test_num_segment
        self.test_num_crop = test_num_crop
        self.num_aug_sample = num_aug_sample
        self.args = args

        if (mode == "train"):
            assert self.args != None, "Must pass arguments to augmentations"
            if self.args.reprob > 0:
                self.rand_erase = True

        elif mode == 'test' or mode == 'val':
            self.data_resize = transforms.Compose([
                video_transforms.Resize(size=(self.short_side_size), interpolation='bilinear')
            ])
            self.data_transform = transforms.Compose([
                volume_transforms.ClipToTensor(),
                video_transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                        std=[0.26862954, 0.26130258, 0.27577711])
            ])
            self.test_seg = []
            self.test_dataset = []
            self.test_label_array = []
            for idx in range(len(self.labels)):
                for ck in range(self.test_num_segment):
                    for cp in range(self.test_num_crop):
                            sample_label = self.labels[idx]
                            self.test_label_array.append(sample_label)
                            self.test_dataset.append(self.paths[idx])
                            self.test_seg.append((ck, cp))

    def __len__(self):
        if self.mode == "train":
            return len(self.paths)
        elif self.mode == "test" or self.mode == "val":
            return len(self.test_dataset)

    def __getitem__(self, index):
        if self.mode == 'train':
            sample = self.paths[index]
            buffer = self.loadvideo_decord(sample)
            if len(buffer) == 0:
                while len(buffer) == 0:
                    warnings.warn("video {} not correctly loaded during validation".format(sample))
                    index = np.random.randint(self.__len__())
                    sample = self.paths[index]
                    buffer = self.loadvideo_decord(sample)
            if self.num_aug_sample > 1:
                frame_list = []
                label_list = []
                index_list = []
                for _ in range(self.num_aug_sample):
                    new_frames = self._aug_frame(buffer, self.args)
                    frame_list.append(new_frames)
                    label_list.append(torch.nn.functional.one_hot(torch.LongTensor([self.labels[index]]), num_classes=174).squeeze())
                    index_list.append(sample.split("/")[-1].split(".")[0])                
                return {
                    "video_features": frame_list,
                    "labels": label_list,
                    "video_indices": index_list
                }
            else:
                buffer = self._aug_frame(buffer, self.args)
            return {
                "video_features": buffer, 
                "labels": torch.nn.functional.one_hot(torch.LongTensor([self.labels[index]]), num_classes=174).squeeze(),
                "video_indices": sample.split("/")[-1].split(".")[0]
            }

        elif self.mode == 'test' or self.mode == 'val':
            sample = self.test_dataset[index]
            chunk_nb, split_nb = self.test_seg[index]
            buffer = self.loadvideo_decord(sample)

            while len(buffer) == 0:
                warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
                    str(self.test_dataset[index]), chunk_nb, split_nb))
                index = np.random.randint(self.__len__())
                sample = self.test_dataset[index]
                chunk_nb, split_nb = self.test_seg[index]
                buffer = self.loadvideo_decord(sample)

            buffer = self.data_resize(buffer)
            if isinstance(buffer, list):
                buffer = np.stack(buffer, 0)

            spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
                                / (self.test_num_crop - 1)
            temporal_start = chunk_nb # in range(self.test_num_segment)
            spatial_start = int(split_nb * spatial_step)
            if buffer.shape[1] >= buffer.shape[2]:
                buffer = buffer[temporal_start::self.test_num_segment, \
                       spatial_start:spatial_start + self.short_side_size, :, :]
            else:
                buffer = buffer[temporal_start::self.test_num_segment, \
                       :, spatial_start:spatial_start + self.short_side_size, :]

            buffer = self.data_transform(buffer)            
            return {
                "video_features": buffer, 
                "labels": torch.nn.functional.one_hot(torch.LongTensor([self.test_label_array[index]]), num_classes=174).squeeze(),
                "video_indices": sample.split("/")[-1].split(".")[0], 
                "chunk_nbs": chunk_nb, "split_nbs": split_nb
            }
        else:
            raise NameError('mode {} unkown'.format(self.mode))

    def _aug_frame(
        self,
        buffer,
        args,
    ):
        aug_transform = video_transforms.create_random_augment(
            input_size=(self.crop_size, self.crop_size),
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
        )

        buffer = [
            transforms.ToPILImage()(frame) for frame in buffer
        ]

        buffer = aug_transform(buffer)

        buffer = [transforms.ToTensor()(img) for img in buffer]
        buffer = torch.stack(buffer) # T C H W
        buffer = buffer.permute(0, 2, 3, 1) # T H W C 
        
        # T H W C 
        buffer = tensor_normalize(
            buffer, [0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]
        )
        # T H W C -> C T H W.
        buffer = buffer.permute(3, 0, 1, 2)
        # Perform data augmentation.
        scl, asp = (
            [0.08, 1.0],
            [0.75, 1.3333],
        )

        buffer = spatial_sampling(
            buffer,
            spatial_idx=-1,
            min_scale=256,
            max_scale=320,
            crop_size=self.crop_size,
            random_horizontal_flip=False,
            inverse_uniform_sampling=False,
            aspect_ratio=asp,
            scale=scl,
            motion_shift=False
        )

        if self.rand_erase:
            erase_transform = RandomErasing(
                args.reprob,
                mode=args.remode,
                max_count=args.recount,
                num_splits=args.recount,
                device="cpu",
            )
            buffer = buffer.permute(1, 0, 2, 3)
            buffer = erase_transform(buffer)
            buffer = buffer.permute(1, 0, 2, 3)

        return buffer

    def loadvideo_decord(self, sample, sample_rate_scale=1):
        """Load video content using Decord"""
        fname = sample

        if not (os.path.exists(fname)):
            return []

        # avoid hanging issue
        if os.path.getsize(fname) < 1 * 1024:
            print('SKIP: ', fname, " - ", os.path.getsize(fname))
            return []
        try:
            vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
        except:
            print("video cannot be loaded by decord: ", fname)
            return []

        if self.mode == 'test' or self.mode == "val":
            all_index = []
            tick = len(vr) / float(self.num_sample)
            all_index = []
            for i in range(self.test_num_segment):
                all_index = all_index + [int((i * tick / float(self.test_num_segment)) + tick * x) \
                                for x in range(self.num_sample)]
            while len(all_index) < (self.num_sample * self.test_num_segment):
                all_index.append(all_index[-1])
            all_index = list(np.sort(np.array(all_index))) 
            vr.seek(0)
            buffer = vr.get_batch(all_index).asnumpy()
            return buffer

        # handle temporal segments
        average_duration = len(vr) // self.num_sample
        all_index = []
        if average_duration > 0:
            all_index += list(np.multiply(list(range(self.num_sample)), average_duration) + np.random.randint(average_duration,
                                                                                                        size=self.num_sample))
        elif len(vr) > self.num_sample:
            all_index += list(np.sort(np.random.randint(len(vr), size=self.num_sample)))
        else:
            all_index += list(np.zeros((self.num_sample,)))
        all_index = list(np.array(all_index)) 
        vr.seek(0)
        buffer = vr.get_batch(all_index).asnumpy()
        return buffer

def tensor_normalize(tensor, mean, std):
    """
    Normalize a given tensor by subtracting the mean and dividing the std.
    Args:
        tensor (tensor): tensor to normalize.
        mean (tensor or list): mean value to subtract.
        std (tensor or list): std to divide.
    """
    if tensor.dtype == torch.uint8:
        tensor = tensor.float()
        tensor = tensor / 255.0
    if type(mean) == list:
        mean = torch.tensor(mean)
    if type(std) == list:
        std = torch.tensor(std)
    tensor = tensor - mean
    tensor = tensor / std
    return tensor

def spatial_sampling(
    frames,
    spatial_idx=-1,
    min_scale=256,
    max_scale=320,
    crop_size=224,
    random_horizontal_flip=True,
    inverse_uniform_sampling=False,
    aspect_ratio=None,
    scale=None,
    motion_shift=False,
):
    """
    Perform spatial sampling on the given video frames. If spatial_idx is
    -1, perform random scale, random crop, and random flip on the given
    frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
    with the given spatial_idx.
    Args:
        frames (tensor): frames of images sampled from the video. The
            dimension is `num frames` x `height` x `width` x `channel`.
        spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
            or 2, perform left, center, right crop if width is larger than
            height, and perform top, center, buttom crop if height is larger
            than width.
        min_scale (int): the minimal size of scaling.
        max_scale (int): the maximal size of scaling.
        crop_size (int): the size of height and width used to crop the
            frames.
        inverse_uniform_sampling (bool): if True, sample uniformly in
            [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
            scale. If False, take a uniform sample from [min_scale,
            max_scale].
        aspect_ratio (list): Aspect ratio range for resizing.
        scale (list): Scale range for resizing.
        motion_shift (bool): Whether to apply motion shift for resizing.
    Returns:
        frames (tensor): spatially sampled frames.
    """
    assert spatial_idx in [-1, 0, 1, 2]
    if spatial_idx == -1:
        if aspect_ratio is None and scale is None:
            frames, _ = video_transforms.random_short_side_scale_jitter(
                images=frames,
                min_size=min_scale,
                max_size=max_scale,
                inverse_uniform_sampling=inverse_uniform_sampling,
            )
            frames, _ = video_transforms.random_crop(frames, crop_size)
        else:
            transform_func = (
                video_transforms.random_resized_crop_with_shift
                if motion_shift
                else video_transforms.random_resized_crop
            )
            frames = transform_func(
                images=frames,
                target_height=crop_size,
                target_width=crop_size,
                scale=scale,
                ratio=aspect_ratio,
            )
        if random_horizontal_flip:
            frames, _ = video_transforms.horizontal_flip(0.5, frames)
    else:
        # The testing is deterministic and no jitter should be performed.
        # min_scale, max_scale, and crop_size are expect to be the same.
        assert len({min_scale, max_scale, crop_size}) == 1
        frames, _ = video_transforms.random_short_side_scale_jitter(
            frames, min_scale, max_scale
        )
        frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx)
    return frames

In [3]:
from argparse import ArgumentParser

parser = ArgumentParser()

# training arguments
parser.add_argument('--batch_size', default=512, type=int)
parser.add_argument("--criterion_name", type=str, default="binary_crossentropy", choices=["binary_crossentropy"])
parser.add_argument('--balance_classes', default=False, type=lambda x: (str(x).lower() == 'true'))
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--wd', default=0.01, type=float, help="Weight decay (will use Adam if set to 0, AdamW otherwise).")
parser.add_argument('--gradient_clip_val', default=1, type=float)
parser.add_argument('--gpus', default=1, type=int)
parser.add_argument('--num_workers', default=0, type=int)
parser.add_argument('--seed', default=None, type=int)
parser.add_argument('--checkpoint_every_n_epochs', type=int, default=5)
parser.add_argument('--wandb_group', type=str, default="latest")
parser.add_argument('--freeze_backbone', default=False, type=lambda x: (str(x).lower() == 'true'))
parser.add_argument('--gradient_accumulation_steps', default=1, type=int)

# Augmentation params
parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
parser.add_argument('--num_aug_sample', type=int, default=2,
                    help='Repeated_aug (default: 2)')
parser.add_argument('--aa', type=str, default='rand-m7-n4-mstd0.5-inc1', metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m7-n4-mstd0.5-inc1)'),
parser.add_argument('--train_interpolation', type=str, default='bicubic',
                    help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

# Random Erase params
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                    help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
                    help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')

# Mixup params
parser.add_argument('--mixup', type=float, default=0.8,
                    help='mixup alpha, mixup enabled if > 0.')
parser.add_argument('--cutmix', type=float, default=1.0,
                    help='cutmix alpha, cutmix enabled if > 0.')
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup_prob', type=float, default=1.0,
                    help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup_mode', type=str, default='batch',
                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

# dataset arguments
parser.add_argument('--task_name', type=str, required=True)
parser.add_argument('--data_path', type=str, required=True)
parser.add_argument('--label_path', type=str, required=True)
parser.add_argument('--n_frames', default=32, type=int)

# model structure arguments
parser.add_argument("--model_name", type=str, default="encode_pool_classify", choices=["encode_pool_classify"])

# backbone arguments
parser.add_argument("--backbone_name", type=str, default="clip_ViT-B/32")

# encoder arguments
parser.add_argument("--temporal_pooling_name", type=str, default="mean", choices=["mean", "transformer", "identity"])
# transformer specific arguments used if `temporal_pooling_name` is `transformer`
parser.add_argument('--temporal_pooling_transformer_depth', default=3, type=int)
parser.add_argument('--temporal_pooling_transformer_heads', default=4, type=int)
parser.add_argument('--temporal_pooling_transformer_dim', default=512, type=int)
parser.add_argument('--temporal_pooling_transformer_ff_dim', default=512, type=int)
parser.add_argument('--temporal_pooling_transformer_input_dim', default=512, type=int)
parser.add_argument('--temporal_pooling_transformer_emb_dropout', default=0.1, type=float)

# classifier arguments
parser.add_argument("--classification_layer_name", type=str, default="linear", choices=["linear"])
parser.add_argument("--classification_input_dim", type=int, default=512)
parser.add_argument("--num_classes", type=int, required=True)

args = parser.parse_args("--task_name ssv2 --data_path /svl/data/SomethingSomethingV2/20bn-something-something-v2/ \
    --label_path /vision/u/eatang/SSV2 --n_frames 16 --num_classes 174".split())


In [4]:
import numpy as np
import torch
from torch.utils.data._utils.collate import string_classes

# Fix imports, but keep backwards compatibility
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
    from torch._six import container_abcs, int_classes
else:
    import collections.abc as container_abcs
    int_classes = int

def collate_with_pad(batch, allow_pad=True, pad_right=True):
    r"""Puts each data field into a tensor with outer dimension batch size.
    Will pad with zeros if there are sequences of varying lenghts in the batch BUT ONLY IF seq is first dimension.
    Will pad on the right by default, except when `pad_right==False`.
    """
    # print(batch[0].keys())
    # print(batch[0].values())
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage).view(-1, *list(elem.size()))
        ###########################################
        # NEW: if tensors are different lengths PAD
        ###########################################
        it = iter(batch)
        elem_size = torch.Tensor.size(next(it))
        if not all(torch.Tensor.size(elem) == elem_size for elem in it):
            if allow_pad:
                # TRY TO PAD along the first dimension
                max_tensor_len = max(map(lambda tensor: tensor.size(1), batch))
                stacked_padded_tensors = torch.zeros(len(batch), max_tensor_len, elem_size[0], *elem_size[2:])
                for idx_in_batch, tensor in enumerate(batch):
                    if not pad_right:
                        stacked_padded_tensors[idx_in_batch, -1*tensor.size(1):, ...] = tensor.permute((1,0,2,3))   # pad on the left
                    else:
                        stacked_padded_tensors[idx_in_batch, :tensor.size(1), ...] = tensor.permute((1,0,2,3))      # pad on the right
                return stacked_padded_tensors.permute((0,2,1,3,4))
            else:
                raise RuntimeError('each element in list of batch should be of equal size')
        ###########################################
        # END NEW
        ###########################################
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
#             # array of string classes and object
#             if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
#                 raise TypeError(f"Unsupported type: {elem_type}")

            return collate_with_pad([torch.as_tensor(b) for b in batch], pad_right=pad_right)
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: collate_with_pad([d[key] for d in batch], pad_right=pad_right) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(collate_with_pad(samples, pad_right=pad_right) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [collate_with_pad(samples, pad_right=pad_right) for samples in transposed]

    raise TypeError(f"Unsupported type: {elem_type}")


In [5]:
from functools import partial

In [6]:

mode = "train"
dataset = SSV2("/svl/data/SomethingSomethingV2/20bn-something-something-v2/", 
            "/vision/u/eatang/SSV2", 16, mode=mode, 
            crop_size=224, 
            short_side_size=240,
            num_aug_sample=2,
            args=args,
            test_num_segment=2)
dataset = torch.utils.data.Subset(dataset, range(32))
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    collate_fn=partial(collate_with_pad, allow_pad=False, pad_right=True),
)


In [8]:
for i, batch in enumerate(train_loader):
    if i == 0:
        print(batch["video_features"])
        break

[tensor([[[[[ 4.2670e-01,  4.7623e-01,  4.9969e-01,  ...,  1.6055e+00,
             1.6530e+00,  1.6530e+00],
           [ 3.8584e-01,  4.1688e-01,  4.6321e-01,  ...,  1.6401e+00,
             1.6938e+00,  1.6938e+00],
           [ 2.6618e-01,  2.9590e-01,  3.4989e-01,  ...,  1.6217e+00,
             1.6704e+00,  1.6968e+00],
           ...,
           [ 5.4042e-01,  5.8442e-01,  6.0375e-01,  ...,  1.6840e+00,
             1.7386e+00,  1.7584e+00],
           [ 8.6465e-01,  8.6465e-01,  8.5200e-01,  ...,  1.7290e+00,
             1.7645e+00,  1.7844e+00],
           [ 8.6465e-01,  8.6465e-01,  8.6465e-01,  ...,  1.7416e+00,
             1.7645e+00,  1.7844e+00]],

          [[ 1.2150e+00,  1.0565e+00,  8.6622e-01,  ...,  1.7416e+00,
             1.7364e+00,  1.6968e+00],
           [ 2.2069e-01,  8.4783e-01,  1.0044e+00,  ...,  1.7625e+00,
             1.7811e+00,  1.7785e+00],
           [ 2.2714e-01,  8.9306e-01,  1.1505e+00,  ...,  1.7779e+00,
             1.8103e+00,  1.8103e+00],


In [13]:
batch["video_features"][0][0].sh

tensor([[[[ 4.2670e-01,  4.7623e-01,  4.9969e-01,  ...,  1.6055e+00,
            1.6530e+00,  1.6530e+00],
          [ 3.8584e-01,  4.1688e-01,  4.6321e-01,  ...,  1.6401e+00,
            1.6938e+00,  1.6938e+00],
          [ 2.6618e-01,  2.9590e-01,  3.4989e-01,  ...,  1.6217e+00,
            1.6704e+00,  1.6968e+00],
          ...,
          [ 5.4042e-01,  5.8442e-01,  6.0375e-01,  ...,  1.6840e+00,
            1.7386e+00,  1.7584e+00],
          [ 8.6465e-01,  8.6465e-01,  8.5200e-01,  ...,  1.7290e+00,
            1.7645e+00,  1.7844e+00],
          [ 8.6465e-01,  8.6465e-01,  8.6465e-01,  ...,  1.7416e+00,
            1.7645e+00,  1.7844e+00]],

         [[ 1.2150e+00,  1.0565e+00,  8.6622e-01,  ...,  1.7416e+00,
            1.7364e+00,  1.6968e+00],
          [ 2.2069e-01,  8.4783e-01,  1.0044e+00,  ...,  1.7625e+00,
            1.7811e+00,  1.7785e+00],
          [ 2.2714e-01,  8.9306e-01,  1.1505e+00,  ...,  1.7779e+00,
            1.8103e+00,  1.8103e+00],
          ...,
     