In [101]:
%matplotlib tk

import sys
import re
from os.path import exists

import numpy as np
import pandas as pd
import starfile as sf

from sklearn.cluster import AgglomerativeClustering
import scipy.interpolate as spin

import matplotlib as mpl
import matplotlib.pyplot as plt

# np.set_printoptions(threshold=sys.maxsize)

## Function definitions

In [181]:
def get_ribo_from_star(star_file):
    """Load a RELION STAR file and tag every particle with its tilt-series (TS) number.

    Parameters
    ----------
    star_file : str
        Path to a STAR file with ``optics`` and ``particles`` blocks.

    Returns
    -------
    ribo_star : mapping of block name -> pandas.DataFrame
        As returned by ``starfile.read``; the ``particles`` block gains an
        integer ``rlnTS`` column.
    TS_list : numpy.ndarray
        Unique TS numbers, in order of first appearance.
    pixel_size_nm : float
        ``rlnImagePixelSize`` of the first optics row multiplied by 0.1
        (Angstrom -> nm).
    """
    # Parse the file once and reuse it (it used to be read twice).
    ribo_star = sf.read(star_file)

    # Only the first optics row is used -- assumes a single optics group.
    pixel_size_nm = ribo_star['optics'].rlnImagePixelSize.values[0] * 0.1

    # TS number = trailing "_<n>" of the second path component of rlnImageName,
    # e.g. "Particles/Position_36/..." -> 36.
    ribo_star['particles']['rlnTS'] = [
        int(name.split('/')[1].split('_')[-1])
        for name in ribo_star['particles'].rlnImageName.values
    ]
    TS_list = pd.unique(ribo_star['particles'].rlnTS)

    return ribo_star, TS_list, pixel_size_nm

In [184]:
def get_coords(star_df_in, TS, model_bin, star_bin):
    """Return the XYZ particle coordinates of one tilt series, rescaled to model binning.

    Rows of ``star_df_in`` whose ``rlnTS`` equals ``TS`` are selected, and their
    ``rlnCoordinateX/Y/Z`` values are multiplied by ``star_bin / model_bin`` so they
    share the pixel grid of the surface model.

    Returns an (n_particles, 3) float array in the original row order.
    """
    coord_cols = ['rlnCoordinateX', 'rlnCoordinateY', 'rlnCoordinateZ']
    in_ts = star_df_in['rlnTS'] == TS
    xyz = star_df_in.loc[in_ts, coord_cols].to_numpy()
    # Keep the evaluation order (xyz * star_bin) / model_bin.
    return xyz * star_bin / model_bin

In [172]:
def get_model(model_file):
    """Load a surface-model point cloud and return its rows sorted by z.

    Parameters
    ----------
    model_file : str
        Whitespace-delimited text file readable by ``numpy.loadtxt`` with at
        least three columns; columns 0-2 are treated as x, y, z.

    Returns
    -------
    numpy.ndarray
        (n_points, n_cols) array with rows ordered by ascending column 2 (z).
    """
    # ndmin=2 keeps a single-point file 2-D so the column indexing below works
    # (plain loadtxt would return a 1-D row and model[:, 2] would fail).
    model = np.loadtxt(model_file, ndmin=2)
    return model[model[:, 2].argsort()]

In [173]:
def segment_surfaces(model_in):
    """Split a two-surface model point cloud into its lower and upper surface.

    Single-linkage agglomerative clustering into two clusters separates the two
    point sheets. The cluster ids the clustering assigns are arbitrary, so the
    cluster with the smaller mean z is returned as ``model_lower`` and the other
    as ``model_upper``.

    Parameters
    ----------
    model_in : numpy.ndarray
        (n_points, >=3) array; column 2 is z.

    Returns
    -------
    labels : numpy.ndarray
        Cluster id (0 or 1) per row of ``model_in``, as produced by the clustering.
    model_lower, model_upper : numpy.ndarray
        Rows of ``model_in`` in the lower / upper cluster, in original row order.
    """
    ac = AgglomerativeClustering(n_clusters=2, linkage="single")
    # Fit on the argument -- the previous version accidentally used the global `model`.
    ac.fit(model_in)
    labels = ac.labels_

    cluster_a = model_in[labels == 0]
    cluster_b = model_in[labels == 1]

    # Order by mean z so "lower"/"upper" does not depend on arbitrary cluster ids.
    if cluster_a[:, 2].mean() <= cluster_b[:, 2].mean():
        model_lower, model_upper = cluster_a, cluster_b
    else:
        model_lower, model_upper = cluster_b, cluster_a

    return labels, model_lower, model_upper

In [174]:
def _surface_grid(surface, N):
    """Sample a z(x, y) point surface on an N x N grid spanning its x/y bounding box.

    Returns an (N, N, 3) array of (x, y, z); z is NaN where a grid node lies
    outside the convex hull of the surface's x/y points.
    """
    xs = np.linspace(surface[:, 0].min(), surface[:, 0].max(), N)
    ys = np.linspace(surface[:, 1].min(), surface[:, 1].max(), N)
    grid_x, grid_y = np.meshgrid(xs, ys)
    interp = spin.LinearNDInterpolator(surface[:, :2], surface[:, 2])
    return np.dstack((grid_x, grid_y, interp(grid_x, grid_y)))


def _min_distances(grid, coords):
    """Euclidean distance from each row of ``coords`` to its nearest non-NaN grid node."""
    dists = np.empty(len(coords))
    for i, point in enumerate(coords):
        # NaN nodes (outside the hull) are ignored by nanmin.
        dists[i] = np.nanmin(np.linalg.norm(grid - point, axis=2))
    return dists


def interpolator(coords_in, upper_in, lower_in, N):
    """Interpolate both surfaces on a grid and measure each point's distance to them.

    Parameters
    ----------
    coords_in : numpy.ndarray
        (n_points, 3) query coordinates, in the same units as the surfaces.
    upper_in, lower_in : numpy.ndarray
        (n, >=3) point clouds of the upper / lower surface; columns 0-2 are x, y, z.
    N : int
        Number of grid samples per axis for each surface.

    Returns
    -------
    interped_top, interped_bot : numpy.ndarray
        (N, N, 3) interpolated surface grids (z is NaN outside the convex hull).
    to_edge : numpy.ndarray
        (n_points, 2) array: distance to the nearest node of the top grid and of
        the bottom grid.

    Note: distances are to grid nodes, not to the continuous surface, so they
    are accurate only up to the grid spacing set by ``N``.
    """
    interped_top = _surface_grid(upper_in, N)
    interped_bot = _surface_grid(lower_in, N)

    to_edge = np.column_stack((
        _min_distances(interped_top, coords_in),
        _min_distances(interped_bot, coords_in),
    ))

    return interped_top, interped_bot, to_edge

## Getting model / ribosome coords

In [179]:
# Binning factors: surface-model point files vs. STAR-file particle coordinates.
model_bin, star_bin = 8, 2

# Per-tomogram surface model file; "<TS>" is substituted with the tilt-series number.
model_file_format = "Position_<TS>_bin8_filtered2.txt"

In [182]:
star_file = "ribosomes/bin2_postM_refinement_conv.star"
# Particle blocks (with an added rlnTS column), unique tilt-series ids, pixel size in nm.
(ribo_star, TS_list, pixel_size_nm) = get_ribo_from_star(star_file=star_file)

## Loop through the tilt series (TS): compute each particle's distance to the nearest lamella surface

In [185]:
# For every tilt series: rescale particle coordinates to the model grid, split the
# model into its two surfaces, measure each particle's distance to both, and store
# the smaller one (converted to nm) in rlnDistToEdge_nm.
interp_grid_size = 100  # N x N samples per interpolated surface

for curr_ts in TS_list:
    # Literal placeholder substitution -- no regex needed.
    model_file = model_file_format.replace("<TS>", str(curr_ts))
    # Explicit check instead of assert + bare except: asserts vanish under -O and
    # a bare except would also hide unrelated errors.
    if not exists(model_file):
        print(f"WARNING: {model_file} doesn't exist. TS{curr_ts} skipped.")
        continue

    # From here on, particle coordinates and model points share model-binned pixels.
    ribo = get_coords(ribo_star['particles'], curr_ts, model_bin, star_bin)
    model = get_model(model_file)

    # Segmentation of surfaces -- segment_surfaces returns (labels, lower, upper).
    labels, model_lower, model_upper = segment_surfaces(model)

    # Plane interpolation and point-to-surface distances (model-binned pixels).
    interped_top, interped_bot, to_edge = interpolator(ribo, model_upper, model_lower, interp_grid_size)

    # Aggregation of data
    df = pd.DataFrame({
        "x": ribo[:, 0],
        "y": ribo[:, 1],
        "z": ribo[:, 2],
        "to_top": to_edge[:, 0],
        "to_bottom": to_edge[:, 1],
    })
    # ndarray.min propagates NaN, as before.
    df["to_any_edge"] = to_edge.min(axis=1)

    # Update of star-DataFrame: model-binned pixels -> nm.
    # Assumes pixel_size_nm is the pixel size at star_bin, so one model pixel is
    # pixel_size_nm * model_bin / star_bin nm -- TODO confirm against the optics block.
    # Row order matches because get_coords uses the same rlnTS mask.
    ribo_star['particles'].loc[ribo_star['particles'].rlnTS == curr_ts, 'rlnDistToEdge_nm'] = (
        df["to_any_edge"].to_numpy() * pixel_size_nm * model_bin / star_bin
    )

In [187]:
pd.options.display.max_rows = 10
# Only particles that received a distance, i.e. from tilt series that had a model file.
ribo_star['particles'].loc[ribo_star['particles']['rlnDistToEdge_nm'].notna()]

Unnamed: 0,rlnCoordinateX,rlnCoordinateY,rlnCoordinateZ,rlnAngleRot,rlnAngleTilt,rlnAnglePsi,rlnImageName,rlnCtfImage,rlnRandomSubset,rlnOpticsGroup,rlnTS,rlnDistToEdge_nm
8118,1337.2490,400.4035,535.7270,126.46280,20.26509,-97.67905,Particles/Position_36/Position_36_ribo06_00000...,Particles/Position_36/Position_36_ribo06_00000...,2,1,36,174.940150
8119,1784.5280,1884.7200,457.0119,-11.46472,167.08530,59.68486,Particles/Position_36/Position_36_ribo06_00000...,Particles/Position_36/Position_36_ribo06_00000...,1,1,36,183.381995
8120,1313.0810,825.5162,861.7173,-174.27860,22.51733,-69.87370,Particles/Position_36/Position_36_ribo06_00000...,Particles/Position_36/Position_36_ribo06_00000...,2,1,36,59.474072
8121,1617.4870,190.4373,189.2904,-68.75554,148.56390,-117.83550,Particles/Position_36/Position_36_ribo06_00000...,Particles/Position_36/Position_36_ribo06_00000...,1,1,36,65.112862
8122,1188.8830,460.2473,817.0438,163.07020,64.31494,145.93260,Particles/Position_36/Position_36_ribo06_00000...,Particles/Position_36/Position_36_ribo06_00000...,2,1,36,86.963044
...,...,...,...,...,...,...,...,...,...,...,...,...
13375,1001.0370,1448.7940,371.2605,68.80367,80.41988,129.98090,Particles/Position_99/Position_99_ribo06_00005...,Particles/Position_99/Position_99_ribo06_00005...,2,1,99,37.302906
13376,794.2349,1206.1050,420.6743,-153.80690,113.10310,82.39798,Particles/Position_99/Position_99_ribo06_00005...,Particles/Position_99/Position_99_ribo06_00005...,1,1,99,32.425480
13377,1177.6690,203.5547,395.1078,-90.39685,142.08800,93.44049,Particles/Position_99/Position_99_ribo06_00005...,Particles/Position_99/Position_99_ribo06_00005...,2,1,99,51.280616
13378,926.5816,1121.4620,506.4562,21.47708,84.77400,-162.28090,Particles/Position_99/Position_99_ribo06_00005...,Particles/Position_99/Position_99_ribo06_00005...,1,1,99,74.079435


## Write out to star file

In [188]:
# Output path for the STAR data, which now includes the added rlnTS and
# rlnDistToEdge_nm particle columns.
new_star_path = "mystar.star"

# NOTE(review): overwrite=True presumably allows replacing an existing file;
# check the installed starfile version still accepts this keyword.
sf.write(ribo_star, new_star_path, overwrite=True)

## Visualisation

In [93]:
N = 50

# The N particles of the last processed tilt series lying closest to the top surface.
closest_points = df.sort_values("to_top").iloc[:N]

In [31]:
# 3D view of the last processed tilt series: both model surfaces and the particles,
# coloured by their distance to the nearest surface.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(projection="3d")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
ax.set_title("Surface model and particles (last processed tilt series)")

# model_lower in blue, model_upper in green (values left over from the loop).
ax.plot_trisurf(model_lower[:,0], model_lower[:,1], model_lower[:,2], color='b', antialiased=True)
ax.plot_trisurf(model_upper[:,0], model_upper[:,1], model_upper[:,2], color='g', antialiased=True)
sc = ax.scatter(ribo[:,0], ribo[:,1], ribo[:,2], c=df.to_any_edge, cmap="inferno_r")
# Without a colorbar the distance colour scale cannot be read.
fig.colorbar(sc, ax=ax, shrink=0.6, label="distance to nearest surface (model-binned pixels)")
plt.tight_layout()

In [14]:
# Display the N particles closest to the top surface.
closest_points

Unnamed: 0,x,y,z,to_top,to_bottom,to_any_edge
188,216.67385,89.94135,228.71615,10.703467,137.452438,10.703467
189,210.675425,382.15925,230.072075,12.712387,143.912875,12.712387
57,172.26365,269.42875,242.20965,12.801857,145.673725,12.801857
247,285.45825,419.75925,206.98305,13.482779,140.743655,13.482779
82,154.111425,55.98195,241.11425,13.918548,131.515485,13.918548
279,264.40025,458.70575,210.229725,14.16894,135.552078,14.16894
396,140.872075,426.53025,247.29885,14.806094,136.340569,14.806094
273,78.655275,335.08375,264.4765,15.941909,143.036335,15.941909
255,288.0125,167.65905,206.3933,16.0159,136.614513,16.0159
538,196.76385,227.34785,232.339275,16.052,140.183574,16.052
