In [78]:
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import sklearn.metrics as metrics
import numpy as np

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm


from data import ModelNet40, ScanObjectNN
from model import PointNet, DGCNN
from util import cal_loss, cal_loss_mix, IOStream
import gco

In [79]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

args = argparse.Namespace(batch_size=10, data='MN40', dropout=0.5, emb_dims=1024, epochs=50, eval=False, exp_name='SageMix', k=20, lr=0.0001, model='pointnet', model_path='', momentum=0.9, no_cuda=False, num_points=1024, seed=1, sigma=-1, test_batch_size=16, theta=0.2, use_sgd=True)

In [80]:
num_points = 1024
dataset = ModelNet40(partition='train', num_points=num_points)
batch_size=args.batch_size

test_batch_size = args.test_batch_size
train_loader = DataLoader(dataset, num_workers=8,
                        batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(ModelNet40(partition='test', num_points=num_points), num_workers=8,
                        batch_size=test_batch_size, shuffle=True, drop_last=False)
num_class=40

In [81]:
print('train data: ', len(dataset))

train data:  9840


In [82]:
if args.data == 'MN40':
    dataset = ModelNet40(partition='train', num_points=args.num_points)
    # args.batch_size = len(dataset)
    # args.batch_size = 
    #print('args.batch_size:',args.batch_size)
    train_loader = DataLoader(dataset, num_workers=8,
                            batch_size=args.batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=8,
                            batch_size=args.test_batch_size, shuffle=True, drop_last=False)
    num_class=40
elif args.data == 'SONN_easy':
    train_loader = DataLoader(ScanObjectNN(partition='train', num_points=args.num_points, ver="easy"), num_workers=8,
                            batch_size=args.batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(ScanObjectNN(partition='test', num_points=args.num_points, ver="easy"), num_workers=8,
                            batch_size=args.test_batch_size, shuffle=True, drop_last=False)
    num_class =15
elif args.data == 'SONN_hard':
    train_loader = DataLoader(ScanObjectNN(partition='train', num_points=args.num_points, ver="hard"), num_workers=8,
                            batch_size=args.batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(ScanObjectNN(partition='test', num_points=args.num_points, ver="hard"), num_workers=8,
                            batch_size=args.test_batch_size, shuffle=True, drop_last=False)
    num_class =15


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#Try to load models
if args.model == 'pointnet':
    model = PointNet(args, num_class).to(device)
elif args.model == 'dgcnn':
    model = DGCNN(args, num_class).to(device)
else:
    raise Exception("Not implemented")

In [83]:
from emd_ import emd_module
class SageMix:
    def __init__(self, args, num_class=40):
        self.num_class = num_class
        self.EMD = emd_module.emdModule()
        self.sigma = args.sigma
        self.beta = torch.distributions.beta.Beta(torch.tensor([args.theta]), torch.tensor([args.theta]))
        self.beta2 = torch.distributions.beta.Beta(torch.tensor([2*args.theta]), torch.tensor([args.theta]))

    
    def mix(self, xyz, label, saliency=None, mixing_idx=0, device='cuda:0'):
        """
        Args:
            xyz (B,N,3)
            label (B)
            saliency (B,N): Defaults to None.
        """        
        # #print("xyz shape", xyz.shape)
        B, N, _ = xyz.shape
        if mixing_idx == 0:
            idxs = torch.randperm(B)

            # #print(xyz)
            
            #Optimal assignment in Eq.(3)
            perm = xyz[idxs]
            
            _, ass = self.EMD(xyz, perm, 0.005, 500) # mapping
            ass = ass.long()
            # #print(ass)
            perm_new = torch.zeros_like(perm).to(device)#.cuda()
            perm_saliency = torch.zeros_like(saliency).to(device)#.cuda()
            
            # #print(ass,ass.shape)
            for i in range(B):
                perm_new[i] = perm[i][ass[i]]
                # #print(idxs)
                # #print(ass)
                # #print(saliency)
                # #print("idxs shape", idxs.shape)
                # #print("ass shape", ass.shape)
                # #print("saliency shape", saliency.shape)
                perm_saliency[i] = saliency[idxs][i][ass[i]]
            
            #####
            # Saliency-guided sequential sampling
            #####
            #Eq.(4) in the main paper
            saliency = saliency/saliency.sum(-1, keepdim=True)
            anc_idx = torch.multinomial(saliency, 1, replacement=True)
            anchor_ori = xyz[torch.arange(B), anc_idx[:,0]]
            
            #cal distance and reweighting saliency map for Eq.(5) in the main paper
            sub = perm_new - anchor_ori[:,None,:]
            dist = ((sub) ** 2).sum(2).sqrt()
            perm_saliency = perm_saliency * dist
            perm_saliency = perm_saliency/perm_saliency.sum(-1, keepdim=True)
            
            #Eq.(5) in the main paper
            anc_idx2 = torch.multinomial(perm_saliency, 1, replacement=True)
            anchor_perm = perm_new[torch.arange(B),anc_idx2[:,0]]
                    
                    
            #####
            # Shape-preserving continuous Mixup
            #####
            alpha = self.beta.sample((B,)).to(device)#.cuda()
            sub_ori = xyz - anchor_ori[:,None,:]
            sub_ori = ((sub_ori) ** 2).sum(2).sqrt()
            #Eq.(6) for first sample
            ker_weight_ori = torch.exp(-0.5 * (sub_ori ** 2) / (self.sigma ** 2))  #(M,N)
            
            sub_perm = perm_new - anchor_perm[:,None,:]
            sub_perm = ((sub_perm) ** 2).sum(2).sqrt()
            #Eq.(6) for second sample
            ker_weight_perm = torch.exp(-0.5 * (sub_perm ** 2) / (self.sigma ** 2))  #(M,N)
            
            #Eq.(9)
            weight_ori = ker_weight_ori * alpha 
            weight_perm = ker_weight_perm * (1-alpha)
            weight = (torch.cat([weight_ori[...,None],weight_perm[...,None]],-1)) + 1e-16
            weight = weight/weight.sum(-1)[...,None]

            #Eq.(8) for new sample
            x = weight[:,:,0:1] * xyz + weight[:,:,1:] * perm_new
            
            #Eq.(8) for new label
            target = weight.sum(1)
            target = target / target.sum(-1, keepdim=True)
            
            label_onehot = torch.zeros(B, self.num_class).to(device).scatter(1, label.view(-1, 1), 1)
            label_perm_onehot = label_onehot[idxs]
            label = target[:, 0, None] * label_onehot + target[:, 1, None] * label_perm_onehot
            return x, label
        
        else:
            # #print("xyz shape mixing 1", xyz.shape)
            B, N, _ = xyz.shape
            split_idx = int(B/2)
            # #print("split_idx", split_idx)
            # #print("saliency shape", saliency.shape)

            xyz1 = xyz[:split_idx]
            xyz2 = xyz[split_idx:]
            label1 = label[:split_idx]
            label2 = label[split_idx:]
            saliency1 = saliency[:split_idx]
            saliency2 = saliency[split_idx:]

            _, ass = self.EMD(xyz1, xyz2, 0.005, 500) # mapping
            ass = ass.long()

            #####
            # Saliency-guided sequential sampling
            #####
            #Eq.(4) in the main paper
            saliency1 = saliency1/saliency1.sum(-1, keepdim=True)
            anc_idx = torch.multinomial(saliency1, 1, replacement=True)
            anchor_ori = xyz1[torch.arange(split_idx), anc_idx[:,0]]

            #cal distance and reweighting saliency map for Eq.(5) in the main paper
            sub = xyz2 - anchor_ori[:,None,:]
            dist = ((sub) ** 2).sum(2).sqrt()
            # #print("saliency2 shape", saliency2.shape)
            # #print("dist shape", dist.shape)
            saliency2 = saliency2 * dist
            saliency2 = saliency2/saliency2.sum(-1, keepdim=True)
            
            #Eq.(5) in the main paper
            anc_idx2 = torch.multinomial(saliency2, 1, replacement=True)
            anchor_2 = xyz2[torch.arange(split_idx),anc_idx2[:,0]]

            alpha = self.beta.sample((split_idx,)).to(device)#.cuda()
            sub_ori = xyz1 - anchor_ori[:,None,:]
            sub_ori = ((sub_ori) ** 2).sum(2).sqrt()
            #Eq.(6) for first sample
            ker_weight_ori = torch.exp(-0.5 * (sub_ori ** 2) / (self.sigma ** 2))  #(M,N)

            # #print("anchor_2 shape", anchor_2.shape)
            sub_perm = xyz2 - anchor_2[:,None,:]
            sub_perm = ((sub_perm) ** 2).sum(2).sqrt()
            #Eq.(6) for second sample
            ker_weight_perm = torch.exp(-0.5 * (sub_perm ** 2) / (self.sigma ** 2))  #(M,N)

            #Eq.(9)
            weight_ori = ker_weight_ori * alpha
            weight_perm = ker_weight_perm * (1-alpha)
            weight = (torch.cat([weight_ori[...,None],weight_perm[...,None]],-1)) + 1e-16
            weight = weight/weight.sum(-1)[...,None]

            #Eq.(8) for new sample
            x = weight[:,:,0:1] * xyz1 + weight[:,:,1:] * xyz2

            #Eq.(8) for new label
            target = weight.sum(1)
            target = target / target.sum(-1, keepdim=True)

            label = target[:, 0, None] * label1 + target[:, 1, None] * label2



            return x, label
    

In [85]:
# #!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import sklearn.metrics as metrics
import numpy as np

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm


from data import ModelNet40, ScanObjectNN
from model import PointNet, DGCNN
from util import cal_loss, cal_loss_mix, IOStream

io = IOStream('checkpoints/' + args.exp_name + '/run.log')
io.cprint(str(args))

if args.use_sgd:
    #print("Use SGD")
    opt = optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4)
else:
    #print("Use Adam")
    opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)


best_test_acc = 0
sagemix=SageMix(args, num_class)
scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr)
criterion = cal_loss_mix


def interleave(a, b):
    # Check that a and b have the same shape
    assert a.shape == b.shape, "Tensors must have the same shape"
    
    # Expand dimensions
    a = a.unsqueeze(1)
    b = b.unsqueeze(1)

    # Concatenate tensors
    c = torch.cat((a, b), dim=1)

    # Reshape tensor
    c = c.view(-1, *a.shape[2:])

    return c
import open3d as o3d

for epoch in range(args.epochs):

    ####################
    # Train
    ####################
    train_loss = 0.0
    count = 0.0
    model.train()
    train_pred = []
    train_true = []
    for data, label in tqdm(train_loader):
        data, label = data.to(device), label.to(device).squeeze()
        # #print("data shape", data)
        batch_size = data.size()[0]
        split_idx = int(np.ceil(batch_size * 1/2))
        data01 = data[:split_idx, :, :]
        label01 = label[:split_idx]
        data2 = data[split_idx:, :, :]
        label2 = label[split_idx:]

        print(label01)
        #before mixing
        # for id,dat in enumerate(data01):
        #     pcd = o3d.geometry.PointCloud()
        #     pcd.points = o3d.utility.Vector3dVector(dat.cpu().numpy())
        #     o3d.io.write_point_cloud("clouds/before_mixing_{}.pcd".format(id), pcd)
        #     # o3d.visualization.draw_geometries([pcd])
        

        # for id,dat in enumerate(data2):
        #     pcd = o3d.geometry.PointCloud()
        #     pcd.points = o3d.utility.Vector3dVector(dat.cpu().numpy())
        #     o3d.io.write_point_cloud("clouds/before_mixing_second_stage_{}.pcd".format(id), pcd)

        
        print("data01:",data01.shape)
        print("data2:",data2.shape)
        ####################
        # generate augmented sample
        ####################
        model.eval()
        data_var = Variable(data.permute(0,2,1), requires_grad=True)
        logits = model(data_var)
        loss = cal_loss(logits, label, smoothing=False)
        loss.backward()
        opt.zero_grad()
        saliency = torch.sqrt(torch.mean(data_var.grad**2,1))
        # #print("saliency shape", saliency.shape)
        # #print("data01 shape", data01.shape)
        data_mix, label_mix = sagemix.mix(data01, label01, saliency[:split_idx,:], mixing_idx = 0)

        # for id,dat in enumerate(data_mix):
        #     pcd = o3d.geometry.PointCloud()
        #     pcd.points = o3d.utility.Vector3dVector(dat.cpu().numpy())
        #     o3d.io.write_point_cloud("clouds/after_mix_first_stage_{}.pcd".format(id), pcd)

        label2_onehot = torch.zeros(label2.shape[0], num_class).to(device).scatter(1, label2.view(-1, 1), 1)
        data_all = torch.cat((data_mix, data2), dim=0)
        # #print("data_all shape", data_all.shape)
        # label_all = interleave(label_mix, label2_onehot)
        label_all = torch.cat((label_mix, label2_onehot), dim=0)

        data_var = Variable(data_mix.permute(0,2,1), requires_grad=True)
        logits = model(data_var)
        loss_mix = criterion(logits, label_mix)
        loss_mix.backward()
        opt.zero_grad()
        saliency_mix = torch.sqrt(torch.mean(data_var.grad**2,1))

        

        # saliency_all = interleave(saliency_mix, saliency[split_idx:, :])
        saliency_all = torch.cat((saliency_mix, saliency[split_idx:,:]), dim=0)

        data_total_mix, label_total_mix = sagemix.mix(data_all, label_all, saliency_all, mixing_idx=1)
        # break
        # for id,dat in enumerate(data_total_mix):
        #     pcd = o3d.geometry.PointCloud()
        #     pcd.points = o3d.utility.Vector3dVector(dat.cpu().numpy())
        #     o3d.io.write_point_cloud("clouds/after_hierarchical_mixing_{}.pcd".format(id), pcd)
        break
        # #print(saliency_all[0,:])
        # #print(saliency_all[25,:])
        # #print("saliency_all shape", saliency_all.shape)
        # break

        # data_allmix, label_allmix = sagemix.mix(data_all, label_, saliency)
        # #print("label_all shape", label_all.shape)
        
        # mixed_saliency = torch.sqrt(torch.mean(data_var.grad**2,1))
        # #print("data shape", data.shape)
        # model.train()
        # break
                
            
        # data3, label3 = sagemix.mix(data2, label2, saliency2, mixing_idx=1)

        
        # mixed_saliency = torch.sqrt(torch.mean(data_var.grad**2,1))
        # #print("data shape", data.shape)
        model.train()
        # # break
            
        opt.zero_grad()
        # opt.zero_grad()
        logits = model(data_total_mix.permute(0,2,1))
        loss = criterion(logits, label_total_mix)
        loss.backward()
        opt.step()
        preds = logits.max(dim=1)[1]
        count += batch_size
        train_loss += loss.item() * batch_size
        # logits3 = model(data3.permute(0,2,1))
        # loss3 = criterion(logits3, label3)
        # loss3.backward()
        # opt.step()
        # preds = logits3.max(dim=1)[1]
        # count += batch_size
        # train_loss += loss3.item() * batch_size
    break
    scheduler.step()
    outstr = 'Train %d, loss: %.6f' % (epoch, train_loss*1.0/count)
    io.cprint(outstr)

    ####################
    # Test
    ####################
    test_loss = 0.0
    count = 0.0
    model.eval()
    test_pred = []
    test_true = []
    for data, label in tqdm(test_loader):
        data, label = data.to(device), label.to(device).squeeze()
        data = data.permute(0, 2, 1)
        batch_size = data.size()[0]
        logits = model(data)
        loss = cal_loss(logits, label)
        preds = logits.max(dim=1)[1]
        count += batch_size
        test_loss += loss.item() * batch_size
        test_true.append(label.cpu().numpy())
        test_pred.append(preds.detach().cpu().numpy())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    test_acc = metrics.accuracy_score(test_true, test_pred)
    avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred)
    if test_acc >= best_test_acc:
        best_test_acc = test_acc
        torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % args.exp_name)
    outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f, best test acc: %.6f' % (epoch,
                                                                            test_loss*1.0/count,
                                                                            test_acc,
                                                                            avg_per_class_acc,
                                                                            best_test_acc)
    io.cprint(outstr)
    






Namespace(batch_size=10, data='MN40', dropout=0.5, emb_dims=1024, epochs=50, eval=False, exp_name='SageMix', k=20, lr=0.0001, model='pointnet', model_path='', momentum=0.9, no_cuda=False, num_points=1024, seed=1, sigma=-1, test_batch_size=16, theta=0.2, use_sgd=True)


  0%|          | 0/984 [00:00<?, ?it/s]

tensor([34, 22,  4,  4, 35], device='cuda:0')
data01: torch.Size([5, 1024, 3])
data2: torch.Size([5, 1024, 3])


  0%|          | 0/984 [00:00<?, ?it/s]


In [62]:
count = 0
for data, label in tqdm(train_loader):
    print(label)
    count +=1
    if (count > 10):
        break

  0%|          | 10/2460 [00:00<02:47, 14.65it/s]

tensor([[20],
        [37],
        [11],
        [ 4]])
tensor([[32],
        [35],
        [ 8],
        [17]])
tensor([[35],
        [14],
        [25],
        [38]])
tensor([[30],
        [ 5],
        [ 4],
        [ 4]])
tensor([[12],
        [25],
        [37],
        [38]])
tensor([[ 1],
        [22],
        [37],
        [ 2]])
tensor([[30],
        [ 8],
        [ 8],
        [30]])
tensor([[15],
        [ 4],
        [36],
        [ 4]])
tensor([[ 7],
        [16],
        [18],
        [ 6]])
tensor([[22],
        [35],
        [18],
        [ 9]])
tensor([[12],
        [ 5],
        [ 4],
        [22]])





In [65]:
import plotly.graph_objects as go
import numpy as np

# Assume you have 'points1' and 'points2' as your point cloud data. 
# The shape of points should be Nx3 (N points in 3D)
points1 = np.random.rand(100, 3)
points2 = np.random.rand(100, 3) + 2  # +2 to separate the two point clouds

# Define the 3D scatter plots
scatter1 = go.Scatter3d(
    x=points1[:, 0],
    y=points1[:, 1],
    z=points1[:, 2],
    mode='markers',
    marker=dict(
        size=4,
        color='red',  # color set to red
        opacity=0.8
    ),
    name='Point cloud 1'  # name displayed in legend
)

scatter2 = go.Scatter3d(
    x=points2[:, 0],
    y=points2[:, 1],
    z=points2[:, 2],
    mode='markers',
    marker=dict(
        size=4,
        color='blue',  # color set to blue
        opacity=0.8
    ),
    name='Point cloud 2'  # name displayed in legend
)

# Define the layout of the plot with desired width and height and hide the axes
layout = go.Layout(
    autosize=False,  # This enables custom size
    width=500,  # Custom width
    height=500,  # Custom height
    scene=dict(
        xaxis=dict(showticklabels=False, visible=False),
        yaxis=dict(showticklabels=False, visible=False),
        zaxis=dict(showticklabels=False, visible=False)
    ),
    margin=dict(l=0, r=0, b=0, t=0)
)

# Create the plot with both scatter plots
fig = go.Figure(data=[scatter1, scatter2], layout=layout)

# Show the plot
fig.show()


In [77]:
import plotly.graph_objects as go
import numpy as np

# Assume you have 'points' as your point cloud data. 
# The shape of points should be Nx3 (N points in 3D)
# points = np.random.rand(100, 3)
# print(data_all[0].shape)

# points = data_all[0].detach().cpu().numpy()
for data, label in tqdm(train_loader):
    points1 = data[0].detach().cpu().numpy()
    points2 = data[1].detach().cpu().numpy()
    points3 = data[2].detach().cpu().numpy()
    points4 = data[3].detach().cpu().numpy()
    break

print(label)
points2 -=  np.array([10, 0, 0])
points3 +=  np.array([4, 0, 0])
# Define the 3D scatter plot
# Define the 3D scatter plots
scatter1 = go.Scatter3d(
    x=points1[:, 0],
    y=points1[:, 1],
    z=points1[:, 2],
    mode='markers',
    marker=dict(
        size=4,
        color='red',  # color set to red
        opacity=0.8
    ),
    name='Point cloud 1'  # name displayed in legend
)

# scatter2 = go.Scatter3d(
#     x=points2[:, 0],
#     y=points2[:, 1],
#     z=points2[:, 2],
#     mode='markers',
#     marker=dict(
#         size=4,
#         color='blue',  # color set to blue
#         opacity=0.8
#     ),
#     name='Point cloud 2'  # name displayed in legend
# )

# scatter3 = go.Scatter3d(
#     x=points3[:, 0],
#     y=points3[:, 1],
#     z=points3[:, 2],
#     mode='markers',
#     marker=dict(
#         size=4,
#         color='green',  # color set to blue
#         opacity=0.8
#     ),
#     name='Point cloud 4'  # name displayed in legend
# )


# Define the layout of the plot with desired width and height and hide the axes
layout = go.Layout(
    autosize=False,  # This enables custom size
    width=1500,  # Custom width
    height=800,  # Custom height
    scene=dict(
        xaxis=dict(showticklabels=False, visible=False),
        yaxis=dict(showticklabels=False, visible=False),
        zaxis=dict(showticklabels=False, visible=False),
        # annotations=[  # add annotations
        #     dict(
        #         showarrow=False,
        #         x=points1[0, 0],  # x position of the annotation
        #         y=points1[0, 1],  # y position of the annotation
        #         z=points1[0, 2],  # z position of the annotation
        #         text="Point cloud 1",  # text of the annotation
        #         xanchor="left",
        #         xshift=10,
        #         opacity=0.7
        #     ),
        #     dict(
        #         showarrow=False,
        #         x=points2[0, 0],  # x position of the annotation
        #         y=points2[0, 1],  # y position of the annotation
        #         z=points2[0, 2],  # z position of the annotation
        #         text="Point cloud 2",  # text of the annotation
        #         xanchor="left",
        #         xshift=10,
        #         opacity=0.7
        #     )
        # ]
    )
)

# Create the plot
fig = go.Figure(data=[scatter1], layout=layout)

# Show the plot
fig.show()


  0%|          | 0/2460 [00:00<?, ?it/s]

tensor([[30],
        [28],
        [ 2],
        [37]])





In [None]:



import torch

def interleave_tensors(a, b):
    # Check that a and b have the same shape
    assert a.shape == b.shape, "Tensors must have the same shape"
    
    # Expand dimensions
    a = a.unsqueeze(1)
    b = b.unsqueeze(1)

    # Concatenate tensors
    c = torch.cat((a, b), dim=1)

    # Reshape tensor
    c = c.view(-1, *a.shape[2:])

    return c

# Test the function
a = torch.randn(10, 1024, 3)
b = torch.randn(10, 1024, 3)

c = interleave_tensors(a, b)

#print("a", a, "b", b, "c", c)  # Should #print: torch.Size([20, 1024, 3])


a tensor([[[-0.1125, -0.0414, -0.3752],
         [-1.3032, -0.2158, -1.2051],
         [-0.5357,  1.5383,  0.3800],
         ...,
         [-0.3512,  0.4020, -1.1674],
         [ 1.1971,  0.3080,  0.6691],
         [ 0.0381, -0.2588,  1.0040]],

        [[-0.3294,  0.4150, -1.1194],
         [-0.0286,  1.3804,  1.0462],
         [ 0.5630,  1.0149,  0.7656],
         ...,
         [ 0.3430,  0.1150,  0.9111],
         [-0.1814,  0.0393,  0.1343],
         [ 1.1760, -0.5602, -1.4493]],

        [[-1.0897, -1.0384,  1.3633],
         [ 0.6443, -1.3132, -0.6551],
         [ 0.4677, -0.4813, -0.2373],
         ...,
         [-0.2463, -0.4392,  1.3484],
         [ 0.3834, -1.0055,  0.0197],
         [ 0.0544, -1.0365, -0.4156]],

        ...,

        [[ 0.5194,  1.4834,  0.7339],
         [ 0.4470, -1.4777, -0.3954],
         [ 0.9724,  2.6624,  1.5264],
         ...,
         [ 2.2803,  0.2474,  1.4016],
         [-0.4626,  0.1937,  0.4863],
         [-2.1015, -1.3753, -0.6788]],

        