In [60]:
import pickle
import torch

# Load the previously saved dataset from disk.
# NOTE(review): pickle.load can execute arbitrary code while unpickling --
# only point this at files you produced yourself or otherwise trust.
with open('./medical_data.pkl', mode='rb') as file:
    loaded_data = pickle.load(file)

def shape(data):
    """Print every key of *data* followed by the shape of each tensor in its list.

    Tensors are numbered from 1 within each key.
    """
    for name, tensors in data.items():
        print(f"Key: {name}")
        for number, item in enumerate(tensors, start=1):
            print(f"Tensor {number} shape: {item.shape}")

def cartesian_to_spherical(coords):
    """Convert Cartesian (x, y, z) points to spherical (r, theta, phi).

    Args:
        coords: tensor whose last dimension holds (x, y, z), e.g. shape (N, 3).
            Any number of leading dimensions is accepted.

    Returns:
        Tensor with the same leading shape and a last dimension of size 3:
        r     -- Euclidean distance from the origin,
        theta -- polar (zenith) angle measured from the +z axis, in [0, pi],
        phi   -- azimuth in the x-y plane measured from +x, in (-pi, pi].
        A point at the origin maps to (0, 0, 0) instead of producing NaN.
    """
    x, y, z = coords[..., 0], coords[..., 1], coords[..., 2]

    rho = torch.sqrt(x**2 + y**2)       # distance from the z axis
    r = torch.sqrt(x**2 + y**2 + z**2)  # distance from the origin
    # atan2(rho, z) equals acos(z / r) whenever r > 0, but it stays finite at
    # the origin (acos(0 / 0) is NaN, which would poison e.g. zero-padded
    # frames) and is better conditioned near the poles.
    theta = torch.atan2(rho, z)
    phi = torch.atan2(y, x)

    return torch.stack((r, theta, phi), dim=-1)

def cartisian_angle(tensor):
    """Compute the angle formed by every three consecutive points.

    For the points p_0 .. p_{n-1} along dim 1, entry i (0 <= i <= n-3) is the
    angle between vec1 = p_i - p_{i+1} and vec2 = p_{i+1} - p_{i+2}. Both
    vectors point the same way along the chain, so this is the bend angle at
    p_{i+1}: 0 when the three points continue in a straight line, pi for a
    full reversal (i.e. pi minus the interior angle at p_{i+1}).

    Args:
        tensor: floating tensor of shape (batch, n_points, 3).

    Returns:
        Tensor of shape (batch, max(n_points - 2, 0)) with angles in radians,
        on the same device and with the same dtype as *tensor*.
        NOTE(review): a zero-length segment (two coincident consecutive
        points) gives 0/0 and therefore NaN for the affected angle(s).
    """
    # segments[:, i] = p_i - p_{i+1}; consecutive segments form vec1/vec2 pairs.
    segments = tensor[:, :-1, :] - tensor[:, 1:, :]
    vec1 = segments[:, :-1, :]
    vec2 = segments[:, 1:, :]

    dot_product = (vec1 * vec2).sum(dim=-1)
    norm1 = torch.norm(vec1, dim=-1)
    norm2 = torch.norm(vec2, dim=-1)
    cos_angle = dot_product / (norm1 * norm2)
    # Clamp guards against rounding pushing |cos| slightly above 1 (acos -> NaN).
    return torch.acos(cos_angle.clamp(-1, 1))

def spherical_tensor(tensor):
    """Convert a batch of Cartesian point sets to spherical coordinates.

    Args:
        tensor: tensor of shape (batch_size, num_points, C) whose last
            dimension starts with (x, y, z); channels beyond the first three
            are ignored.

    Returns:
        Tensor of shape (batch_size, num_points, 3) holding (r, theta, phi)
        for every point, as computed by cartesian_to_spherical.
    """
    # Unpacking the size enforces a 3-D input, as before.
    batch_size, num_points, _ = tensor.size()

    # cartesian_to_spherical takes a single (N, 3) tensor (the previous loop
    # passed x, y, z separately and raised TypeError), so flatten every point
    # into one matrix, convert in a single call, and restore the batch layout.
    flat = tensor[..., :3].reshape(batch_size * num_points, 3)
    return cartesian_to_spherical(flat).reshape(batch_size, num_points, 3)


def _left_arm_angles(tensor, keep_indices):
    """Return the (frames, len(keep_indices)) angle features of one sequence tensor."""
    extracted = tensor[:, keep_indices, :]  # keep only the selected joints, in chain order

    # Allocate on the input's device/dtype so float64 or GPU data is not
    # silently converted to CPU float32; non-float inputs use the default float dtype.
    dtype = extracted.dtype if extracted.is_floating_point() else torch.get_default_dtype()
    features = torch.zeros(extracted.size(0), extracted.size(1), dtype=dtype, device=extracted.device)

    # Direction of the segment from the first kept joint to the second.
    relative_pos = extracted[:, 1, :] - extracted[:, 0, :]
    spherical_coords = cartesian_to_spherical(relative_pos)
    features[:, 0] = spherical_coords[:, 2]  # shoulder position mapping 1: azimuth (phi)
    features[:, 1] = spherical_coords[:, 1]  # shoulder position mapping 2: zenith angle (theta)
    # Remaining joint angles: one per consecutive triple of kept joints (k - 2 values),
    # so 2 + (k - 2) == k columns in total.
    features[:, 2:] = cartisian_angle(extracted)
    return features


def extract_left_arm_arithmatic(original_data, *, keep_indices=(5, 6, 7, 8, 22), inplace=True):
    """Replace each joint-position sequence with a compact left-arm angle representation.

    For every tensor in every list of *original_data*, the joints listed in
    *keep_indices* are selected and turned into a tensor of shape
    (tensor.size(0), len(keep_indices)):

    * column 0  -- azimuth (phi) of the vector from the first kept joint to the second,
    * column 1  -- polar/zenith angle (theta) of that same vector,
    * columns 2.. -- bend angles from cartisian_angle, one per consecutive
      triple of kept joints.

    Args:
        original_data: dict mapping a key to a list of tensors. Each tensor is
            indexed as tensor[:, joints, :] with (x, y, z) in the last
            dimension -- presumably (frames, joints, 3); TODO confirm.
        keep_indices: joint indices to keep, in chain order; at least two.
            Defaults to the previously hard-coded (5, 6, 7, 8, 22).
            NOTE(review): if the skeletons use the 25-joint NTU RGB+D layout
            (keys such as 'A041' suggest it), these look like 1-based joint
            numbers used on a 0-based axis (0-based index 8 would be the right
            shoulder); confirm against the data specification.
        inplace: if True (default, the original behaviour) the tensors inside
            *original_data*'s lists are overwritten and the same dict is
            returned. If False, a new dict of new lists is returned and the
            input is left untouched.

    Returns:
        Dict with the same keys holding the angle tensors.

    Raises:
        ValueError: if fewer than two joint indices are given.
    """
    keep_indices = list(keep_indices)  # a list is always treated as an index array by torch
    if len(keep_indices) < 2:
        raise ValueError("keep_indices needs at least two joints to define a segment")

    if not inplace:
        return {
            key: [_left_arm_angles(tensor, keep_indices) for tensor in tensor_list]
            for key, tensor_list in original_data.items()
        }

    # Overwrite the caller's lists element by element, as the original did.
    for tensor_list in original_data.values():
        for idx, tensor in enumerate(tensor_list):
            tensor_list[idx] = _left_arm_angles(tensor, keep_indices)
    return original_data

# Turn every sequence into left-arm angle features, then print each resulting shape.
spherical_data = extract_left_arm_arithmatic(original_data=loaded_data)
shape(data=spherical_data)

Key: A041
Tensor 1 shape: torch.Size([109, 5])
Tensor 2 shape: torch.Size([62, 5])
Tensor 3 shape: torch.Size([101, 5])
Tensor 4 shape: torch.Size([72, 5])
Tensor 5 shape: torch.Size([76, 5])
Tensor 6 shape: torch.Size([93, 5])
Tensor 7 shape: torch.Size([68, 5])
Tensor 8 shape: torch.Size([85, 5])
Tensor 9 shape: torch.Size([67, 5])
Tensor 10 shape: torch.Size([76, 5])
Tensor 11 shape: torch.Size([105, 5])
Tensor 12 shape: torch.Size([66, 5])
Tensor 13 shape: torch.Size([85, 5])
Tensor 14 shape: torch.Size([69, 5])
Tensor 15 shape: torch.Size([76, 5])
Tensor 16 shape: torch.Size([74, 5])
Tensor 17 shape: torch.Size([109, 5])
Tensor 18 shape: torch.Size([63, 5])
Tensor 19 shape: torch.Size([101, 5])
Tensor 20 shape: torch.Size([73, 5])
Tensor 21 shape: torch.Size([75, 5])
Tensor 22 shape: torch.Size([94, 5])
Tensor 23 shape: torch.Size([67, 5])
Tensor 24 shape: torch.Size([85, 5])
Tensor 25 shape: torch.Size([67, 5])
Tensor 26 shape: torch.Size([76, 5])
Tensor 27 shape: torch.Size([105