# sphereProjection.ipynb

Takes decoded codebook vectors (embeddings) from `./codebooks`, as produced by nH_vqvae.py, and renders two 2D plots for each...

Our spheres are described in polar (spherical) coordinates: r, theta, and phi. We flatten (average) over theta and over phi separately, which gives two plots.

In [1]:
import os
import pickle

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio

In [2]:
# Number of bins along each polar coordinate. data_to_dict iterates over these
# ranges (radius slowest, phi fastest) when labelling the flat codebook data.
num_radii_bins = 5
num_theta_bins = 10
num_phi_bins = 5
# Number of (r, theta, phi) bins for a single type.
total_bins = num_radii_bins * num_theta_bins * num_phi_bins

In [3]:
def data_to_dict(data):
    """Label a flat sequence of bin values with ``(type, r, t, p)`` keys.

    ``data`` is read in order: the type ('B', 'L', 'R', 'X') varies slowest,
    then the radius bin, then the theta bin, and the phi bin varies fastest.
    Any entries of ``data`` past the last bin are ignored.

    Returns a dict mapping each ``(type, r, t, p)`` tuple to its value.
    """
    bin_keys = (
        (kind, radius, theta, phi)
        for kind in ['B', 'L', 'R', 'X']
        for radius in range(num_radii_bins)
        for theta in range(num_theta_bins)
        for phi in range(num_phi_bins)
    )
    # Index explicitly (rather than zip) so that data shorter than the number
    # of bins still raises instead of being silently truncated.
    return {bin_key: data[position] for position, bin_key in enumerate(bin_keys)}

Breaking it down in terms of RADII, we have:

'B' 0 for the first 50
'B' 1 for the second 50
...
'B' 4 for the fifth 50

then we have

'L' 0 for the first 50
...
'L' 4 for the fifth 50

In other words, within each of the four types ('B', 'L', 'R', 'X') we want to group the data by its five radial bins.

In [4]:
def flatten_by(index, data):
    """Average ``data`` over one component of its tuple keys.

    Parameters
    ----------
    index : int
        Position within each key to collapse.
    data : dict
        Maps tuple keys to values that support ``+`` and ``/`` (numbers or
        numpy arrays).

    Returns
    -------
    dict
        One entry per group of keys that agree everywhere except at
        ``index``.  The collapsed position is set to ``-1`` in the new key and
        the value is the mean of the group.  Entries appear in the order in
        which each group was first seen.

    The values in ``data`` are never modified.
    """
    sums = {}
    counts = {}

    for key, value in data.items():
        # New key with the specified index replaced by -1
        new_key = key[:index] + (-1,) + key[index + 1:]

        if new_key in sums:
            # Out-of-place add: `+=` would write into the caller's array when
            # the first value of the group is a mutable numpy array.
            sums[new_key] = sums[new_key] + value
            counts[new_key] += 1
        else:
            sums[new_key] = value
            counts[new_key] = 1

    # Average (again out-of-place, so the result never aliases the input).
    return {key: total / counts[key] for key, total in sums.items()}

In [5]:
def flatten_phi(data):
    """Average ``data`` over phi and lay the result out as ``[radius][theta]``.

    ``data`` is expected to be keyed by ``(radius, theta, phi)`` bin indices.
    After averaging over position 2 with ``flatten_by`` every key has the form
    ``(radius, theta, -1)``, so each cell is fetched by direct lookup instead
    of scanning the whole dict once per cell.

    Returns a list with one inner list per radius bin; each inner list holds
    the averaged value for every theta bin, in order.  Cells that are absent
    from ``data`` are skipped.
    """
    averaged = flatten_by(2, data)
    flattened = []

    for radius_bin in range(num_radii_bins):
        ring = [
            averaged[(radius_bin, theta_bin, -1)]
            for theta_bin in range(num_theta_bins)
            if (radius_bin, theta_bin, -1) in averaged
        ]
        flattened.append(ring)

    return flattened

In [6]:
def flatten_theta(data):
    """Average ``data`` over theta and lay the result out as ``[radius][phi]``.

    ``data`` is expected to be keyed by ``(radius, theta, phi)`` bin indices.
    After averaging over position 1 with ``flatten_by`` every key has the form
    ``(radius, -1, phi)``, so each cell is fetched by direct lookup instead of
    scanning the whole dict once per cell.

    Returns a list with one inner list per radius bin; each inner list holds
    the averaged value for every phi bin, in order.  Cells that are absent
    from ``data`` are skipped.
    """
    averaged = flatten_by(1, data)
    flattened = []

    for radius_bin in range(num_radii_bins):
        ring = [
            averaged[(radius_bin, -1, phi_bin)]
            for phi_bin in range(num_phi_bins)
            if (radius_bin, -1, phi_bin) in averaged
        ]
        flattened.append(ring)

    return flattened

In [7]:
def spiderweb_plot(num_circles, num_radials, opacities, line_thickness=2, color=(255, 255, 255)):
    """Build a full-circle polar grid whose cells are shaded by ``opacities``.

    Parameters
    ----------
    num_circles : int
        Number of concentric rings, drawn at integer radii 1..num_circles.
    num_radials : int
        Number of equal angular sectors the full circle is divided into.
    opacities : 2-D array
        Indexed as ``opacities[ring, sector]``; each value becomes the alpha
        channel (formatted to two decimals) of that cell's fill.
    line_thickness : number
        Width of the black grid lines.
    color : sequence
        RGB components of the fill colour.

    Returns
    -------
    plotly.graph_objects.Figure
        A 600x600 figure with hidden axes, no legend and no margins.
    """

    def ring_arc(radius, theta_from, theta_to, samples=50):
        # Cartesian samples along a circular arc of the given radius.
        sweep = np.linspace(theta_from, theta_to, samples)
        return radius * np.cos(sweep), radius * np.sin(sweep)

    def grid_stroke():
        # Fresh line spec for every grid shape.
        return dict(color="black", width=line_thickness)

    fig = go.Figure()

    # Concentric rings, added as layout shapes.
    for ring in range(1, num_circles + 1):
        fig.add_shape(
            type="circle",
            xref="x",
            yref="y",
            x0=-ring,
            y0=-ring,
            x1=ring,
            y1=ring,
            line=grid_stroke()
        )

    # Spokes from the origin out to the largest ring.
    for spoke in range(num_radials):
        direction = spoke * 2 * np.pi / num_radials
        fig.add_shape(
            type="line",
            x0=0,
            y0=0,
            x1=np.cos(direction) * num_circles,
            y1=np.sin(direction) * num_circles,
            line=grid_stroke()
        )

    # One filled wedge per (ring, sector) cell.
    for ring in range(num_circles):
        for sector in range(num_radials):
            sector_start = sector * 2 * np.pi / num_radials
            sector_end = (sector + 1) * 2 * np.pi / num_radials

            opacity = opacities[ring, sector]
            fillcolor = f'rgba({color[0]},{color[1]},{color[2]},{opacity:.2f})'

            # Inner arc runs forwards, outer arc backwards, so that the two
            # halves join into a single closed outline.
            inner_x, inner_y = ring_arc(ring, sector_start, sector_end)
            outer_x, outer_y = ring_arc(ring + 1, sector_end, sector_start)

            wedge = go.Scatter(
                x=np.concatenate([inner_x, outer_x]),
                y=np.concatenate([inner_y, outer_y]),
                fill='toself',
                fillcolor=fillcolor,
                line=dict(color='rgba(0,0,0,0)'),
                mode='lines'
            )
            fig.add_trace(wedge)

    fig.update_layout(
        xaxis=dict(visible=False, range=[-num_circles, num_circles]),
        yaxis=dict(visible=False, range=[-num_circles, num_circles]),
        yaxis_scaleanchor="x",
        showlegend=False,
        width=600,
        height=600,
        margin=dict(l=0, r=0, t=0, b=0)
    )

    return fig


In [8]:
def appleslice_plot(num_circles, num_radials, opacities, line_thickness=2, color=(255, 255, 255)):
    """Build a half-disc polar grid whose cells are shaded by ``opacities``.

    The half-disc spans angles from -pi/2 to +pi/2 (the right-hand side).

    Parameters
    ----------
    num_circles : int
        Number of concentric arcs, drawn at integer radii 1..num_circles.
    num_radials : int
        Number of equal angular sectors the half-disc is divided into.
    opacities : 2-D array
        Indexed as ``opacities[ring, sector]``; each value becomes the alpha
        channel (formatted to two decimals) of that cell's fill.
    line_thickness : number
        Width of the black grid lines.
    color : sequence
        RGB components of the fill colour.

    Returns
    -------
    plotly.graph_objects.Figure
        A 600x600 figure with hidden axes, no legend and no margins.
    """

    def arc_points(radius, start_angle, end_angle, num_points=50):
        # Cartesian samples along a circular arc of the given radius.
        angles = np.linspace(start_angle, end_angle, num_points)
        x = radius * np.cos(angles)
        y = radius * np.sin(angles)
        return x, y

    fig = go.Figure()

    # Filled wedges go in first. Traces are painted in the order they are
    # added, so the grid lines added afterwards stay visible on top of even
    # fully opaque cells.
    for i in range(num_circles):
        for j in range(num_radials):
            angle1 = -np.pi/2 + j * np.pi / num_radials
            angle2 = -np.pi/2 + (j + 1) * np.pi / num_radials
            r1, r2 = i, i + 1

            opacity = opacities[i, j]
            fillcolor = f'rgba({color[0]},{color[1]},{color[2]},{opacity:.2f})'

            # Inner arc forwards, outer arc backwards -> one closed outline.
            x_inner, y_inner = arc_points(r1, angle1, angle2)
            x_outer, y_outer = arc_points(r2, angle2, angle1)
            x = np.concatenate([x_inner, x_outer])
            y = np.concatenate([y_inner, y_outer])

            fig.add_trace(go.Scatter(
                x=x,
                y=y,
                fill='toself',
                fillcolor=fillcolor,
                line=dict(color='rgba(0,0,0,0)'),
                mode='lines'
            ))

    # Concentric half-circle arcs.
    for i in range(1, num_circles + 1):
        x, y = arc_points(i, -np.pi/2, np.pi/2)
        fig.add_trace(go.Scatter(
            x=x,
            y=y,
            mode='lines',
            line=dict(color="black", width=line_thickness),
            showlegend=False
        ))

    # Radial lines; num_radials sectors need num_radials + 1 boundaries.
    for i in range(num_radials + 1):
        angle = -np.pi/2 + i * np.pi / num_radials
        x = np.array([0, np.cos(angle) * num_circles])
        y = np.array([0, np.sin(angle) * num_circles])
        fig.add_trace(go.Scatter(
            x=x,
            y=y,
            mode='lines',
            line=dict(color="black", width=line_thickness),
            showlegend=False
        ))

    fig.update_layout(
        xaxis=dict(visible=False, range=[0, num_circles]),  # Trim the left half...
        yaxis=dict(visible=False, range=[-num_circles, num_circles]),
        yaxis_scaleanchor="x",
        showlegend=False,
        width=600,
        height=600,
        margin=dict(l=0, r=0, t=0, b=0)
    )

    return fig

In [9]:
def trim_keys(dict):
    """Return a copy of the mapping with the first element of every key removed.

    If two keys become equal after trimming, the value seen later wins.
    """
    return {full_key[1:]: entry for full_key, entry in dict.items()}

In [10]:
# Min-max normalization
def normalize(arr, min, max):
    """Rescale ``arr`` linearly so that ``min`` maps to 0 and ``max`` maps to 1.

    Parameters
    ----------
    arr : array-like
        Values to rescale; nested lists are accepted and converted to a float
        array.
    min, max : number
        Bounds of the source range.

    Returns
    -------
    numpy.ndarray
        ``(arr - min) / (max - min)``.  When ``max == min`` there is no range
        to scale by, so an all-zero array of the same shape is returned
        instead of NaN/inf values.
    """
    values = np.asarray(arr, dtype=float)
    span = max - min

    # Degenerate range: dividing would yield NaN/inf, which is not a usable
    # opacity downstream.
    if np.ndim(span) == 0 and span == 0:
        return np.zeros_like(values)

    return (values - min) / span

In [11]:
def view(items, color, category):
    """Render and save the phi- and theta-flattened plots for every embedding.

    Parameters
    ----------
    items : list of dict
        One dict per embedding, keyed by ``(radius, theta, phi)`` bin indices.
    color : sequence
        RGB components used to fill the plot cells.
    category : str
        Prefix of the output file names.

    For the embedding at position ``i`` two images are written:
    ``plots/embeddings/<category><i>_phi.png`` (full-circle plot of the data
    averaged over phi) and ``plots/embeddings/<category><i>_theta.png``
    (half-disc plot of the data averaged over theta).  The output directory is
    created if it does not exist.  Nothing is written for empty ``items``.
    """
    if len(items) == 0:
        return

    thetass = [flatten_theta(embedding) for embedding in items]
    phiss = [flatten_phi(embedding) for embedding in items]

    # We use these for normalization... mins and maxes are taken from the
    # dataset AFTER flattening, across all items, so opacities are comparable
    # between the plots produced by one call.
    t_min, t_max = np.min(thetass), np.max(thetass)
    p_min, p_max = np.min(phiss), np.max(phiss)

    # The plot functions index opacities as [ring, sector], so hand them arrays.
    thetas_opacities = [normalize(np.asarray(embedding), t_min, t_max) for embedding in thetass]
    phis_opacities = [normalize(np.asarray(embedding), p_min, p_max) for embedding in phiss]

    os.makedirs('plots/embeddings', exist_ok=True)

    # Iterate over what was actually passed in rather than a fixed count.
    for i in range(len(items)):
        name = category + str(i)

        sp = spiderweb_plot(num_radii_bins, num_theta_bins, phis_opacities[i], 0.5, color)
        ap = appleslice_plot(num_radii_bins, num_phi_bins, thetas_opacities[i], 0.5, color)

        pio.write_image(sp, f'plots/embeddings/{name}_phi.png')
        pio.write_image(ap, f'plots/embeddings/{name}_theta.png')

In [12]:
def visualize(codebook, num_entries=20):
    """Render plots for the first ``num_entries`` entries of ``codebook``.

    Parameters
    ----------
    codebook : sequence
        ``codebook[i]`` is the flat sequence of bin values for entry ``i``,
        in the layout expected by ``data_to_dict``.
    num_entries : int, optional
        How many leading entries to process.  Defaults to 20, the count this
        notebook has always rendered.

    Five sets of plots are produced through ``view``: first the average over
    the four types (empty file-name prefix), then one set per type with
    prefixes 'B', 'R', 'L' and 'X'.
    """
    # One dict per codebook entry, keyed by (type, r, t, p).
    binned_codebooks = [data_to_dict(codebook[i]) for i in range(num_entries)]

    # flatten_by(0, ...) averages over the four types; trim_keys then drops the
    # collapsed type slot, leaving (r, t, p) keys.
    averaged_total = [trim_keys(flatten_by(0, book)) for book in binned_codebooks]

    # (type label, RGB fill colour), in the order the plots are written.
    type_colors = [
        ('B', (31, 135, 232)),   # beta-sheet
        ('R', (211, 38, 38)),    # right-helix
        ('L', (58, 211, 38)),    # left-helix
        ('X', (241, 196, 15)),   # uncharted territory
    ]

    per_type = []
    for label, rgb in type_colors:
        restricted = [
            {key: value for key, value in book.items() if key[0] == label}
            for book in binned_codebooks
        ]
        # Same key handling as for the average, so every set reaches view()
        # with identical (r, t, p) keys.
        trimmed = [trim_keys(flatten_by(0, subset)) for subset in restricted]
        per_type.append((trimmed, rgb, label))

    view(averaged_total, (147, 81, 22), '')
    for trimmed, rgb, label in per_type:
        view(trimmed, rgb, label)
        

In [13]:
# Load the pickled codebook from disk and render all plots for it.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load codebook files that come from a trusted source.
with open('codebooks/mse_codebook.pkl', 'rb') as f:
    codebook = pickle.load(f)

visualize(codebook)