In [1]:
# This notebook demonstrates how to calculate shape scores for classifying protein folds using ZMPY3D with PyTorch.
#
# The original method (https://doi.org/10.1371/journal.pcbi.1007970) involved
#     1) a trimming procedure that preprocesses voxels by removing those far from the center, and
#     2) a dynamic grid width for individual structures.
#
# These preprocessing techniques are non-essential, can be improved, and are unrelated to this application
# note's focus on high-performance GPU-based 3D Zernike moments.
#
#
# This notebook primarily consists of the following steps:
#     1. Install ZMPY3D_PT.
#     2. Define necessary parameters.
#     3. Load precalculated cache.
#     4. Download example PDB data with coordinates.
#     5. Convert coordinate data into a voxel.
#     6. Create a callable function for generating Zernike moments and normalization.
#     7. Obtain the results.
#     8. Run the same comparison through the command-line interface (CLI).

In [None]:
# Install ZMPY3D versions for PyTorch.

! pip install ZMPY3D-PT

print(f"It is recommended to restart the Python kernel for the IPython notebook.")

In [None]:
# Download example data from GitHub using curl
! curl -OJL https://github.com/tawssie/ZMPY3D/raw/main/1WAC_A.txt
! curl -OJL https://github.com/tawssie/ZMPY3D/raw/main/2JL9_A.txt


In [None]:
import os
import pickle

import numpy as np
import torch
import ZMPY3D_PT as z

# Use the GPU when available; tensors created below without an explicit device
# are then allocated on the default device set here.
if torch.cuda.is_available():
    torch.set_default_device('cuda')
    print("CUDA is available, PyTorch uses GPU.")
else:
    print("CUDA is unavailable. PyTorch uses CPU.")

MaxOrder = 20
GridWidth= 1.00
Param=z.get_global_parameter()

# The precalculated caches live in the cache_data directory of the installed
# ZMPY3D_PT package.
CacheDir = os.path.join(os.path.dirname(z.__file__), 'cache_data')
BinomialCacheFilePath = os.path.join(CacheDir, 'BinomialCache.pkl')
# The cache file name encodes MaxOrder, so the loaded cache matches only that order;
# later cells must use the same MaxOrder.
LogCacheFilePath = os.path.join(CacheDir, 'LogG_CLMCache_MaxOrder{:02d}.pkl'.format(MaxOrder))

# NOTE: pickle.load can execute arbitrary code. These files ship with the installed
# package and are trusted; never point these paths at untrusted files.
with open(BinomialCacheFilePath, 'rb') as file:
    BinomialCachePKL = pickle.load(file)

with open(LogCacheFilePath, 'rb') as file:
    CachePKL = pickle.load(file)

BinomialCache = torch.tensor(BinomialCachePKL['BinomialCache'], dtype=torch.float64)

GCache_pqr_linear = torch.tensor(CachePKL['GCache_pqr_linear'])
GCache_complex = torch.tensor(CachePKL['GCache_complex'])
GCache_complex_index = torch.tensor(CachePKL['GCache_complex_index'])
CLMCache3D = torch.tensor(CachePKL['CLMCache3D'], dtype=torch.complex128)
CLMCache = torch.tensor(CachePKL['CLMCache'], dtype=torch.float64)

RotationIndex=CachePKL['RotationIndex']


def _rotation_index(key, one_based=False):
    """Return RotationIndex[key] as a squeezed int64 tensor.

    The [0, 0] indexing and the -1 shift applied to some fields suggest a
    MATLAB-style struct holding 1-based indices (TODO confirm). When one_based
    is True the values are shifted down by one to make them 0-based.
    """
    values = np.squeeze(RotationIndex[key][0, 0])
    if one_based:
        values = values - 1
    return torch.tensor(values, dtype=torch.int64)


s_id        = _rotation_index('s_id', one_based=True)
n           = _rotation_index('n')
l           = _rotation_index('l')
m           = _rotation_index('m')
mu          = _rotation_index('mu')
k           = _rotation_index('k')
IsNLM_Value = _rotation_index('IsNLM_Value', one_based=True)

print(f"Now using the MaxOrder of {MaxOrder} and the GridWidth of {GridWidth}.")
print("Pre-calculated parameters have been loaded successfully.")


In [None]:
%%time

PDBFileName='./1WAC_A.txt'

# Convert structure data into coordinates
[XYZ,AA_NameList]=z.get_pdb_xyz_ca(PDBFileName)
# Convert coordinates into voxels using precalculated Gaussian densities
ResidueBox=z.get_residue_gaussian_density_cache(Param)
[Voxel3D,Corner]=z.fill_voxel_by_weight_density(XYZ,AA_NameList,Param['residue_weight_map'],GridWidth,ResidueBox[GridWidth])

# Convert the voxel data into a PyTorch object
Voxel3D=torch.tensor(Voxel3D,dtype=torch.float64)

print(f"Converting PDB to 3D voxel grid with NumPy on CPU, then transferring to GPU memory as PyTorch objects.")




In [None]:
%%time

Dimension_BBox_scaled=Voxel3D.shape
MaxOrder=torch.tensor(MaxOrder,dtype=torch.int64)

X_sample = torch.arange(0, Dimension_BBox_scaled[0] + 1, dtype=torch.float64)
Y_sample = torch.arange(0, Dimension_BBox_scaled[1] + 1, dtype=torch.float64)
Z_sample = torch.arange(0, Dimension_BBox_scaled[2] + 1, dtype=torch.float64)

# Calculate the volume mass and the center of mass
[VolumeMass,Center,_]=z.calculate_bbox_moment(Voxel3D,1,X_sample,Y_sample,Z_sample)

# Calculate the weights for sphere sampling
[AverageVoxelDist2Center,MaxVoxelDist2Center]=z.calculate_molecular_radius(Voxel3D,Center,VolumeMass,1.80) # Param['default_radius_multiplier'] == 1.80

# Apply weights to the geometric moments
Sphere_X_sample, Sphere_Y_sample, Sphere_Z_sample=z.get_bbox_moment_xyz_sample(Center,AverageVoxelDist2Center,Dimension_BBox_scaled)

_,_,SphereBBoxMoment=z.calculate_bbox_moment(Voxel3D
                                  ,MaxOrder
                                  ,Sphere_X_sample
                                  ,Sphere_Y_sample
                                  ,Sphere_Z_sample)

# Convert to scaled 3D Zernike moments
ZMoment_scaled,ZMoment_raw=z.calculate_bbox_moment_2_zm(MaxOrder
                                    , GCache_complex
                                    , GCache_pqr_linear
                                    , GCache_complex_index
                                    , CLMCache3D
                                    , SphereBBoxMoment)

ZMList = []

# Convert the scaled 3D Zernike moments into 3DZD-based descriptors
ZM_3DZD_invariant=z.get_3dzd_121_descriptor(ZMoment_scaled)

ZMList.append(z.get_3dzd_121_descriptor(ZMoment_scaled))

# Calculate alternative 3D Zernike moments for specific normalisation orders 2, 3, 4, and 5
MaxTargetOrder2NormRotate=5

for TargetOrder2NormRotate in range(2, MaxTargetOrder2NormRotate+1):
    ABList=z.calculate_ab_rotation(ZMoment_raw, TargetOrder2NormRotate)
    ZM=z.calculate_zm_by_ab_rotation(ZMoment_raw, BinomialCache, ABList, MaxOrder, CLMCache,s_id,n,l,m,mu,k,IsNLM_Value)
    ZM_mean, _ = z.get_mean_invariant(ZM)
    ZMList.append(ZM_mean)

# MomentInvariant is a vector that describes shape information
MomentInvariant = torch.cat([torch.flatten(z[~torch.isnan(z)]) for z in ZMList], dim=0)

TotalResidueWeight=z.get_total_residue_weight(AA_NameList,Param['residue_weight_map'])
TotalResidueWeight=torch.tensor(TotalResidueWeight,dtype=torch.float64)

[Prctile_list,STD_XYZ_dist2center,S,K]=z.get_ca_distance_info(torch.tensor(XYZ))

# GeoDescriptor contains geometric information derived from coordinates
GeoDescriptor = torch.cat([
    AverageVoxelDist2Center.flatten(),
    TotalResidueWeight.flatten(),
    Prctile_list.flatten(),
    STD_XYZ_dist2center.flatten(),
    S.flatten(),
    K.flatten()
], dim=0)


print(f"Transforming the gridded voxel into 3D Zernike moments with global normalization, 3DZD style, yields 121 descriptors.")
print(f"Additional normalization is also calculated at targets 2, 3, 4, and 5, for Zernike moment descriptor.")
print(f"Geometric information is calculated and collected as a geometric descriptor based on the coordinates.")
print(f"The dimensions of the voxel being used are {Voxel3D.shape}.")
print(f"Time elapsed is as follows:")


In [None]:
def OneTimeConversion_PT(XYZ, AA_NameList, Voxel3D, MaxOrder,
                         RadiusMultiplier=1.80, MaxTargetOrder2NormRotate=5):
    """Compute the shape and geometric descriptors of one structure in a single call.

    Runs the same steps as the step-by-step cell above:
    geometric moments of the voxel grid, 3D Zernike moments, the 3DZD-style
    121 descriptor, and mean invariants after normalising the raw moments at
    orders 2..MaxTargetOrder2NormRotate. It also computes coordinate-based
    geometric features.

    Reads module-level globals loaded in the setup cell: GCache_complex,
    GCache_pqr_linear, GCache_complex_index, CLMCache3D, CLMCache,
    BinomialCache, the rotation indexes (s_id, n, l, m, mu, k, IsNLM_Value)
    and Param. MaxOrder should therefore match the order of the loaded
    LogG_CLMCache file.

    Args:
        XYZ: CA coordinates, as returned by z.get_pdb_xyz_ca.
        AA_NameList: residue names, as returned by z.get_pdb_xyz_ca.
        Voxel3D: float64 voxel grid tensor with three spatial axes.
        MaxOrder: maximum Zernike order, as a Python int or a 0-d integer tensor.
        RadiusMultiplier: passed to z.calculate_molecular_radius
            (default 1.80, noted as Param['default_radius_multiplier']).
        MaxTargetOrder2NormRotate: highest normalisation order for the extra
            invariants (default 5, i.e. orders 2, 3, 4 and 5).

    Returns:
        (MomentInvariant, GeoDescriptor): 1-D tensors. NaN entries are
        removed from MomentInvariant.
    """
    Dimension_BBox_scaled = Voxel3D.shape
    # int() accepts a Python int or a 0-d tensor and avoids torch's
    # copy-construct warning when a tensor is passed in.
    MaxOrder = torch.tensor(int(MaxOrder), dtype=torch.int64)

    # Integer sample points 0..size along each voxel axis.
    X_sample = torch.arange(0, Dimension_BBox_scaled[0] + 1, dtype=torch.float64)
    Y_sample = torch.arange(0, Dimension_BBox_scaled[1] + 1, dtype=torch.float64)
    Z_sample = torch.arange(0, Dimension_BBox_scaled[2] + 1, dtype=torch.float64)

    # Order-1 geometric moments give the volume mass and the center of mass.
    VolumeMass, Center, _ = z.calculate_bbox_moment(Voxel3D, 1, X_sample, Y_sample, Z_sample)

    # Only the average voxel distance is used downstream (sphere sampling and GeoDescriptor).
    AverageVoxelDist2Center, _ = z.calculate_molecular_radius(Voxel3D, Center, VolumeMass, RadiusMultiplier)

    # Sample coordinates derived from the center and average radius, then
    # geometric moments up to MaxOrder on those samples.
    Sphere_X_sample, Sphere_Y_sample, Sphere_Z_sample = z.get_bbox_moment_xyz_sample(
        Center, AverageVoxelDist2Center, Dimension_BBox_scaled)
    _, _, SphereBBoxMoment = z.calculate_bbox_moment(
        Voxel3D, MaxOrder, Sphere_X_sample, Sphere_Y_sample, Sphere_Z_sample)

    # Geometric moments -> scaled and raw 3D Zernike moments.
    ZMoment_scaled, ZMoment_raw = z.calculate_bbox_moment_2_zm(
        MaxOrder, GCache_complex, GCache_pqr_linear, GCache_complex_index, CLMCache3D, SphereBBoxMoment)

    # 3DZD-style 121-descriptor invariant first, then the normalised mean invariants.
    ZMList = [z.get_3dzd_121_descriptor(ZMoment_scaled)]
    for TargetOrder2NormRotate in range(2, MaxTargetOrder2NormRotate + 1):
        ABList = z.calculate_ab_rotation(ZMoment_raw, TargetOrder2NormRotate)
        ZM = z.calculate_zm_by_ab_rotation(ZMoment_raw, BinomialCache, ABList, MaxOrder, CLMCache,
                                           s_id, n, l, m, mu, k, IsNLM_Value)
        ZM_mean, _ = z.get_mean_invariant(ZM)
        ZMList.append(ZM_mean)

    # Flatten and drop NaNs. The loop variable is `zm`, not `z`, so the
    # ZMPY3D_PT module alias stays unambiguous.
    MomentInvariant = torch.cat([torch.flatten(zm[~torch.isnan(zm)]) for zm in ZMList], dim=0)

    TotalResidueWeight = torch.tensor(
        z.get_total_residue_weight(AA_NameList, Param['residue_weight_map']), dtype=torch.float64)

    Prctile_list, STD_XYZ_dist2center, S, K = z.get_ca_distance_info(torch.tensor(XYZ))

    GeoDescriptor = torch.cat([
        AverageVoxelDist2Center.flatten(),
        TotalResidueWeight.flatten(),
        Prctile_list.flatten(),
        STD_XYZ_dist2center.flatten(),
        S.flatten(),
        K.flatten()
    ], dim=0)

    return MomentInvariant, GeoDescriptor


print("Merge all steps into a single callable PyTorch function, OneTimeConversion_PT.")


In [None]:
%%time

ResidueBox=z.get_residue_gaussian_density_cache(Param)

# These weights and indexes are predefined in the paper available at https://doi.org/10.1371/journal.pcbi.1007970.
P=z.get_descriptor_property()
ZMIndex = torch.cat([
    torch.tensor(P['ZMIndex0']),
    torch.tensor(P['ZMIndex1']),
    torch.tensor(P['ZMIndex2']),
    torch.tensor(P['ZMIndex3']),
    torch.tensor(P['ZMIndex4'])
], dim=0)

ZMWeight = torch.cat([
    torch.tensor(P['ZMWeight0']),
    torch.tensor(P['ZMWeight1']),
    torch.tensor(P['ZMWeight2']),
    torch.tensor(P['ZMWeight3']),
    torch.tensor(P['ZMWeight4'])
], dim=0)

# Transforming coordinates into voxels
PDBFileName_A='./1WAC_A.txt'
[XYZ_A,AA_NameList_A]=z.get_pdb_xyz_ca(PDBFileName_A)
[Voxel3D_A,Corner_A]=z.fill_voxel_by_weight_density(XYZ_A,AA_NameList_A,Param['residue_weight_map'],GridWidth,ResidueBox[GridWidth])
Voxel3D_A=torch.tensor(Voxel3D_A,dtype=torch.float64)

PDBFileName_B='./2JL9_A.txt'
[XYZ_B,AA_NameList_B]=z.get_pdb_xyz_ca(PDBFileName_B)
[Voxel3D_B,Corner_B]=z.fill_voxel_by_weight_density(XYZ_B,AA_NameList_B,Param['residue_weight_map'],GridWidth,ResidueBox[GridWidth])
Voxel3D_B=torch.tensor(Voxel3D_B,dtype=torch.float64)

# Retrieve descriptors
MomentInvariantRawA, GeoDescriptorA=OneTimeConversion_PT(XYZ_A,AA_NameList_A,Voxel3D_A,20)
MomentInvariantRawB, GeoDescriptorB=OneTimeConversion_PT(XYZ_B,AA_NameList_B,Voxel3D_B,20)


# Computing scores using normalized Zernike moments with specified weights and indices

ZMScore = torch.sum(torch.abs(MomentInvariantRawA[ZMIndex] - MomentInvariantRawB[ZMIndex]) * ZMWeight)

# Computing scores from coordinates using specified weights and indices
GeoWeight = torch.tensor(P['GeoWeight'], dtype=torch.float64).flatten()
GeoScore = torch.sum(GeoWeight * (2 * torch.abs(GeoDescriptorA - GeoDescriptorB) / (1 + torch.abs(GeoDescriptorA) + torch.abs(GeoDescriptorB))))


# Rescale scores, where the score is a metric relative to a predefined threshold, not a statistic
GeoScoreScaled = (6.6 - GeoScore) / 6.6 * 100.0
ZMScoreScaled = (9.0 - ZMScore) / 9.0 * 100.0


print(f"Converting two PDB files 1WAC_A.txt and 2JL9_A.txt to gridded voxels, transferring to GPU memory as PyTorch objects.")
print(f"Calculate all descriptors for two structures at GridWidth {GridWidth}, deriving the similarity scores.")

print(f'GeoScore {GeoScoreScaled:.2f} TotalZMScore {ZMScoreScaled:.2f}')
print(f"Time elapsed is as follows:")

In [None]:
# Alternatively, use a system call to compute results via CLI
# ./ZMPY3D_PT_CLI_ShapeScore PDB_A PDB_B GridWidth
! ZMPY3D_PT_CLI_ShapeScore "./1WAC_A.txt" "./2JL9_A.txt" 1.0
