In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Image, display
import matplotlib.image as mpimg
from tqdm import tqdm
import math
import os
import json
from types import SimpleNamespace
from foldingdiff.tokenizer import Tokenizer
from collections import defaultdict
from foldingdiff.datasets import *
from foldingdiff.algo import compute_rmsd
# Switch the working directory to the repository checkout: later cells address
# checkpoints with paths relative to it ('./ckpts/...').
# NOTE(review): the `bin.encode` import on the next line presumably also relies
# on this cwd being importable -- confirm before moving this line.
os.chdir('/n/holylfs06/LABS/mzitnik_lab/Users/msun415/foldingdiff')
from bin.encode import BPE

In [None]:
def modified(t):
    """Return the keys of ``t.bond_to_token`` whose value has a tuple at index 1.

    Keys are returned in the mapping's iteration order.
    """
    return [
        key
        for key, entry in t.bond_to_token.items()
        if isinstance(entry[1], tuple)
    ]


def compare(t1, t2):
    """Return ``compute_rmsd`` of the coordinates computed by ``t1`` and ``t2``.

    Both arguments must expose a ``compute_coords()`` method; ``t1`` is
    evaluated first.
    """
    coords_first = t1.compute_coords()
    coords_second = t2.compute_coords()
    return compute_rmsd(coords_first, coords_second)


def vis_images(*paths):
    """
    Display an arbitrary number of images in a square-ish grid layout.

    Parameters:
    *paths: variable number of image file paths (``str`` or ``os.PathLike``).
        Each panel is titled with the base name of its file.

    Returns:
    None. When called without paths, prints a notice and returns without
    creating a figure.
    """
    n = len(paths)
    if n == 0:
        print("No images to display.")
        return

    # Determine grid size (close to square)
    n_cols = math.ceil(math.sqrt(n))
    n_rows = math.ceil(n / n_cols)

    # squeeze=False makes subplots() return a 2-D array of Axes for every grid
    # size (including 1x1), so no special case is needed before flattening.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(n_cols * 4, n_rows * 4), squeeze=False
    )
    axes = axes.ravel()

    # Display each image
    for ax, path in zip(axes, paths):
        ax.imshow(mpimg.imread(path))
        # os.fspath lets pathlib.Path objects through; str.split would raise on them.
        ax.set_title(os.path.basename(os.fspath(path)))
        ax.axis('off')

    # Hide any unused subplots
    for ax in axes[n:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Exclusive upper bound of the BPE iteration indices scanned for checkpoint files.
NO_ITERS = 500
# Stride between the iteration indices that are scanned.
STEP_ITER = 10
# Divisor of the dashed reference line (y = K / ratio) drawn in the L-vs-K plot.
ratio = 10
# Name of the run directory under ./ckpts/ that this notebook analyses.
# The commented-out assignments below are alternative runs kept for quick switching.
d = "1751601353.45675" # bins 3
# d = "1751601353.4568286" # bins 4
# d = "1751395339.1781707"
# d = "1751395338.979964"


In [None]:


# Build the dataset that the reference tokenizers are derived from.
dataset = FullCathCanonicalCoordsDataset(
    'all',
    use_cache=False,
    debug=False,
    zero_center=False,
    toy=1000,
    pad=512,
    secondary=False,
    trim_strategy="discard",
)

# Keep only structures whose psi column has at most one undefined entry;
# report every structure that is dropped.
cleaned_structures = []
for i, struc in enumerate(dataset.structures):
    psi = struc['angles']['psi']
    # `psi == psi` is False exactly where a float entry is NaN, so the sum
    # counts the defined psi values.
    n_defined = (psi == psi).sum()
    if n_defined >= len(psi) - 1:
        cleaned_structures.append(struc)
        continue
    print(f"skipping {i}, {struc['fname']} because of missing dihedrals")
dataset.structures = cleaned_structures

# Reference BPE object over the cleaned structures; only initialize() is
# called on it here. Its tokenizers serve as the comparison baseline later on.
ref = BPE(
    dataset.structures,
    bins={1: 3},
    bin_strategy='uniform',
    save_dir=f'./ckpts/{d}',
    rmsd_partition_min_size=2,
    num_partitions=10,
    compute_sec_structs=False,
    plot_iou_with_sec_structs=False,
    res_init=True,
)
ref.initialize()

In [None]:
# Sweep the saved BPE checkpoints and record, for each one:
#   K   - vocabulary size, len(bpe._tokens)
#   L   - mean number of bond_to_token entries per tokenizer
#   err - mean compare() value against the reference tokenizers, taken over
#         at most MAX_ERR_SAMPLES structures to keep the sweep fast
MAX_ERR_SAMPLES = 10

Ks, Ls, errs = [], [], []
for t in range(0, NO_ITERS, STEP_ITER):
    path = f'./ckpts/{d}/bpe_iter={t}.pkl'
    if not os.path.exists(path):
        # Stop at the first missing checkpoint.
        break
    # NOTE: unpickling can execute arbitrary code -- only load checkpoints
    # produced by a trusted run. The context manager closes the file handle.
    with open(path, 'rb') as fh:
        bpe = pickle.load(fh)
    usage = [len(tok.bond_to_token) for tok in bpe.tokenizers]
    N = len(bpe.tokenizers)  # number of tokenized structures in this checkpoint
    K = len(bpe._tokens)
    L = np.mean(usage)
    err = np.mean([
        compare(bpe.tokenizers[i], ref.tokenizers[i])
        for i in tqdm(range(min(N, MAX_ERR_SAMPLES)))
    ])
    errs.append(err)
    Ks.append(K)
    Ls.append(L)

if not Ks:
    # Fail with a clear message instead of an opaque empty-array reduction error below.
    raise FileNotFoundError(f"no BPE checkpoints found under ./ckpts/{d}")

Ks = np.array(Ks)
Ls = np.array(Ls)
errs = np.array(errs)
# N is deliberately NOT reassigned here: the title reports the number of
# structures (N) separately from the number of BPE rounds (len(Ks)).

# make figure + first (left) axis
fig, ax1 = plt.subplots(figsize=(8, 5))

# plot L vs K on left y-axis, with a dashed reference line L = K / ratio
x_diag = np.linspace(Ks.min(), Ks.max(), 100)
ax1.plot(x_diag, x_diag/ratio, linestyle='--', label=f"L=K/{ratio} (K/L={ratio:.1f})")
ax1.plot(Ks, Ls, marker='o', label="L vs K", linewidth=2)
ax1.set_xlabel("K (Vocab Size) Each Round")
ax1.set_ylabel("L  (#Motif-Tokens Per PDB)")
ax1.set_xticks(Ks)

# create a second y-axis that shares the same x
ax2 = ax1.twinx()
ax2.plot(Ks, errs, marker='x', linestyle=':', label="Error", linewidth=2, color="tab:red")
ax2.set_ylabel("Error", color="tab:red")
ax2.tick_params(axis="y", labelcolor="tab:red")

# combine legends from both axes
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc="best")

ax1.set_title(f"L vs K for N={N} w/ {len(Ks)} BPE rounds")
fig.tight_layout()
plt.show()

In [None]:
# Load the checkpoint saved at BPE iteration `t` for the inspection cells below.
t = 0
path = f'./ckpts/{d}/bpe_iter={t}.pkl'
# NOTE: unpickling can execute arbitrary code -- only load checkpoints produced
# by a trusted run. The context manager closes the file handle after loading.
with open(path, 'rb') as fh:
    bpe = pickle.load(fh)


In [None]:
# Show the tokens_iter=2000.png image stored in this run's checkpoint directory.
# The path is relative to the working directory set in the import cell, matching
# the './ckpts/{d}/...' convention used by the checkpoint-loading cells.
vis_images(f'./ckpts/{d}/tokens_iter=2000.png')

In [None]:
# For every tokenizer of the loaded checkpoint, print the entries that
# `modified` flags (position in the list, key, and the stored value).
for index, t in enumerate(bpe.tokenizers):
    for k in modified(t):
        print(index, k, t.bond_to_token[k])

# Pick one tokenizer for the detailed inspection in the following cells.
index = 6
t = bpe.tokenizers[index]

In [None]:
# Bond window (first bond index and window length) passed to
# visualize_bonds() and token_geo() in the cells below.
start = 69
length = 6
# Key into bpe._tokens of the vocabulary entry shown next to that window.
occur = (30, 8)

In [None]:
# Absolute output locations (one directory above the cwd) for the four renders.
path, ref_path, bond_path, ref_bond_path = (
    os.path.abspath(rel)
    for rel in ('../test.png', '../ref.png', '../test_bonds.png', '../ref_bonds.png')
)

# Render the selected tokenizer and its reference counterpart: first the whole
# structure, then only the bond window (start, length).
ref_tokenizer = ref.tokenizers[index]
t.visualize(path)
ref_tokenizer.visualize(ref_path)
t.visualize_bonds(start, length, bond_path)
ref_tokenizer.visualize_bonds(start, length, ref_bond_path)

# Reference window first, checkpoint window second.
vis_images(ref_bond_path, bond_path)
# Alternative view kept for reference: the window next to the saved key images.
# vis_images(*([bond_path] + [f'./ckpts/{d}/key_iter=0_{i}.png' for i in range(10)]))

In [None]:
# Show the two full-structure renders written by the previous cell:
# reference tokenizer first, checkpoint tokenizer second.
vis_images(ref_path, path)

In [None]:
# Display, as a pair, what token_geo() returns for the bond window
# (start, length) of tokenizer `t` and the vocabulary entry bpe._tokens[occur].
# NOTE(review): presumably the two are expected to agree -- confirm.
t.token_geo(start, length), bpe._tokens[occur]

In [None]:
# Display the `fname` of the reference tokenizer and of the checkpoint tokenizer
# at the same index. NOTE(review): these should presumably match -- confirm.
ref.tokenizers[index].fname, t.fname

In [None]:
# Display the complete bond_to_token mapping of the selected tokenizer.
t.bond_to_token

In [None]:
# Geometry of the whole structure, used later as ground truth.
# NOTE(review): 3*t.n-1 is presumably the total bond count -- confirm.
full = t.token_geo(0, 3*t.n-1)

# Flatten the tokenizer into a sequence of ("MOTIF", token) entries, each
# followed (except when the motif reaches the end of the chain) by the two
# dihedrals and the bond angle read at the junction after it.
# The loop targets are named tok_start / tok_length so that the notebook-level
# `start` / `length` used by the visualisation cells are not overwritten.
tokenized = []
for (tok_start, motif, tok_length) in t.bond_to_token.values():
    tokenized.append(("MOTIF", motif))
    # find the dihedral
    b = tok_start + tok_length  # find dihedral around this bond
    if b < 3*t.n-1:
        dt = Tokenizer.DIHEDRAL_ANGLES[(b-2)%3]
        tokenized.append(("DIHEDRAL", dt, t._dihedral_angle(b-2)))
        dt = Tokenizer.DIHEDRAL_ANGLES[(b-1)%3]
        tokenized.append(("DIHEDRAL", dt, t._dihedral_angle(b-1)))
        # separate name: the original reused `bt` for both the motif token and the angle type
        angle_name = Tokenizer.BOND_ANGLES[(b-1)%3]
        tokenized.append(("BOND_ANGLE", angle_name, t._bond_angle(b-1)))

In [None]:
# Rebuild per-type value lists from the token sequence: a MOTIF token
# contributes every list stored in its vocabulary entry, any other token
# contributes its single measured value under its own type name.
repl = defaultdict(list)
for kind, *payload in tokenized:
    if kind != "MOTIF":
        repl[payload[0]].append(payload[1])
        continue
    motif_geo = bpe._tokens[payload[0]]
    # Entries may be stored as (possibly repeatedly) JSON-encoded strings.
    while isinstance(motif_geo, str):
        motif_geo = json.loads(motif_geo)
    for name in motif_geo:
        repl[name].extend(motif_geo[name])

# Freeze into a plain dict so that a missing type raises KeyError afterwards.
repl = dict(repl)

In [None]:
# For each geometry type, print the indices at which the values rebuilt from
# the tokens differ from the full-structure values.
all_types = t.BOND_TYPES + t.BOND_ANGLES + t.DIHEDRAL_ANGLES
for angle_type in all_types:
    rebuilt = np.array(repl[angle_type])
    reference = np.array(full[angle_type])
    print(np.argwhere(rebuilt != reference))

In [None]:
# Display the full token sequence built above (MOTIF / DIHEDRAL / BOND_ANGLE tuples).
tokenized