In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Image, display
import matplotlib.image as mpimg
from tqdm import tqdm
import math
import os
import json
from types import SimpleNamespace
from foldingdiff.tokenizer import Tokenizer
from collections import defaultdict
from foldingdiff.datasets import *
from foldingdiff.algo import compute_rmsd
os.chdir('/n/holylfs06/LABS/mzitnik_lab/Users/msun415/foldingdiff')
from bin.encode import BPE

In [None]:
def modified(t):
    """List the `bond_to_token` keys whose entry holds a tuple in position 1.

    Keys are returned in the dict's iteration order.
    """
    return [
        bond
        for bond, entry in t.bond_to_token.items()
        if isinstance(entry[1], tuple)
    ]


def compare(t1, t2):
    """Return `compute_rmsd` of the coordinates reconstructed by two tokenizers.

    `t1.compute_coords()` is evaluated before `t2.compute_coords()`.
    """
    coords_a = t1.compute_coords()
    coords_b = t2.compute_coords()
    return compute_rmsd(coords_a, coords_b)


def vis_images(*paths):
    """
    Display an arbitrary number of images in a square-ish grid layout.

    Each subplot is titled with the image's file name (``os.path.basename``),
    so both ``str`` and ``os.PathLike`` (e.g. ``pathlib.Path``) arguments work.

    Parameters:
    *paths: variable number of file paths to images (str or os.PathLike)

    Returns:
    None. When no paths are given, prints a message and returns early.
    """
    n = len(paths)
    if n == 0:
        print("No images to display.")
        return

    # Grid close to square: ceil(sqrt(n)) columns, as many rows as needed.
    n_cols = math.ceil(math.sqrt(n))
    n_rows = math.ceil(n / n_cols)

    # squeeze=False always returns a 2-D array of Axes (even for a 1x1 grid),
    # so no special case is needed for a single bare Axes object.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(n_cols * 4, n_rows * 4), squeeze=False
    )
    flat_axes = axes.ravel()

    # Display each image, titled by its file name.
    for ax, path in zip(flat_axes, paths):
        ax.imshow(mpimg.imread(path))
        ax.set_title(os.path.basename(path))
        ax.axis('off')

    # Blank out grid cells that received no image.
    for ax in flat_axes[n:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Analysis configuration.
NO_ITERS = 500   # exclusive upper bound on BPE iterations scanned for checkpoints
STEP_ITER = 10   # stride between inspected checkpoints
ratio = 10       # K/L ratio of the dashed reference line L = K / ratio in `plot`
# Run directory under ./ckpts; the trailing note records that run's bin setting.
d = "1751936564.1540673" # bins 3
# d = "1751601353.4568286" # bins 4
# d = "1751395339.1781707"
# d = "1751395338.979964"

# Echo the arguments the run was launched with (file is closed afterwards).
with open(f"./ckpts/{d}/args.txt") as args_file:
    args = args_file.readlines()
for line in args:
    print(line.rstrip('\n'))


In [None]:

# Build the dataset ('repeat' is the first positional argument of the loader).
dataset = FullCathCanonicalCoordsDataset(
    'repeat',
    use_cache=False,
    debug=False,
    zero_center=False,
    toy=0,
    pad=512,
    secondary=False,
    trim_strategy="discard",
)

# Keep structures with at most one missing psi dihedral. NaN != NaN, so
# `psi == psi` is True exactly at the non-missing entries.
cleaned_structures = []
for i, struc in enumerate(dataset.structures):
    n_present = (struc['angles']['psi'] == struc['angles']['psi']).sum()
    if n_present >= len(struc['angles']['psi']) - 1:
        cleaned_structures.append(struc)
        continue
    print(f"skipping {i}, {struc['fname']} because of missing dihedrals")
dataset.structures = cleaned_structures

# Reference BPE over the cleaned structures; `plot` compares checkpoint
# tokenizers against `ref.tokenizers` index by index.
ref = BPE(
    dataset.structures,
    bins={1: 3},
    bin_strategy='uniform',
    save_dir=f'./ckpts/{d}',
    rmsd_partition_min_size=2,
    num_partitions=10,
    compute_sec_structs=False,
    plot_iou_with_sec_structs=False,
    res_init=True,
)
ref.initialize()

In [None]:
def _load_pickle(pkl_path):
    """Load and return the object pickled at `pkl_path`, closing the file.

    NOTE: pickle.load can execute arbitrary code; only use on trusted checkpoints.
    """
    with open(pkl_path, 'rb') as fh:
        return pickle.load(fh)


def _collect_bpe_stats(ckpt_dir, no_iters, step_iter, reference, max_compare):
    """Scan BPE checkpoints and collect per-round vocab size, usage and error.

    Reads `{ckpt_dir}/bpe_iter={it}.pkl` for it = 0, step_iter, ... < no_iters,
    stopping at the first missing file.

    Returns:
        (Ks, Ls, errs, n_structures): numpy arrays with one entry per checkpoint
        read -- `len(bpe._tokens)`, the mean number of `bond_to_token` entries
        per tokenizer, and the mean `compare` error of the first `max_compare`
        tokenizers against `reference.tokenizers` -- plus the number of
        tokenizers in the last checkpoint read (0 if none was found).
    """
    Ks, Ls, errs = [], [], []
    n_structures = 0
    for it in range(0, no_iters, step_iter):
        # Separate name so the caller's output path is never overwritten.
        ckpt_path = f'{ckpt_dir}/bpe_iter={it}.pkl'
        if not os.path.exists(ckpt_path):
            break
        ckpt = _load_pickle(ckpt_path)
        n_structures = len(ckpt.tokenizers)
        errors = []
        for i in tqdm(range(min(n_structures, max_compare))):
            try:
                errors.append(compare(ckpt.tokenizers[i], reference.tokenizers[i]))
            except Exception:
                print(i)  # report which tokenizer failed, then propagate
                raise
        Ks.append(len(ckpt._tokens))
        Ls.append(np.mean([len(tok.bond_to_token) for tok in ckpt.tokenizers]))
        errs.append(np.mean(errors))
    return np.array(Ks), np.array(Ls), np.array(errs), n_structures


def plot(path, max_compare=10):
    """Plot L (mean motif tokens per structure) and error against K (vocab size).

    Reads checkpoints `./ckpts/{d}/bpe_iter={t}.pkl` every STEP_ITER iterations
    (below NO_ITERS), compares each against the reference `ref`, then draws
    L vs K on the left axis, error vs K on the right axis, and a dashed
    reference line L = K / ratio.

    Args:
        path: file the figure is saved to (format inferred from its extension,
            e.g. './ckpts/<run>/run.png').
        max_compare: number of leading tokenizers compared against `ref` per
            checkpoint (default 10, the previously hard-coded value).

    Returns:
        None. Prints a message and draws nothing if no checkpoint exists.
    """
    ckpt_dir = f'./ckpts/{d}'
    Ks, Ls, errs, n_structures = _collect_bpe_stats(
        ckpt_dir, NO_ITERS, STEP_ITER, ref, max_compare
    )
    if Ks.size == 0:
        print(f"No BPE checkpoints found under {ckpt_dir}; nothing to plot.")
        return

    # make figure + first (left) axis
    fig, ax1 = plt.subplots(figsize=(8, 5))

    # Dashed reference line L = K / ratio (constant K/L ratio), then L vs K.
    x_diag = np.linspace(Ks.min(), Ks.max(), 100)
    ax1.plot(x_diag, x_diag / ratio, linestyle='--',
             label=f"L=K/{ratio:.1f} (K/L={ratio:.1f})")
    ax1.plot(Ks, Ls, marker='o', label="L vs K", linewidth=2)
    ax1.set_xlabel("K (Vocab Size) Each Round")
    ax1.set_ylabel("L  (#Motif-Tokens Per PDB)")
    ax1.set_xticks(Ks)

    # Second y-axis sharing x for the error curve.
    ax2 = ax1.twinx()
    ax2.plot(Ks, errs, marker='x', linestyle=':', label="Error", linewidth=2, color="tab:red")
    ax2.set_ylabel("Error", color="tab:red")
    ax2.tick_params(axis="y", labelcolor="tab:red")

    # Combine legends from both axes.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc="best")

    ax1.set_title(f"L vs K for N={n_structures} w/ {len(Ks)} BPE rounds")
    fig.tight_layout()
    # Save before showing: show() may close the figure (e.g. inline backend),
    # after which savefig would write a blank image.
    fig.savefig(path)
    plt.show()

# plot(f'./ckpts/{d}/run.png')
# Load the iteration-0 BPE checkpoint; the file handle is closed after loading.
# NOTE: pickle.load executes arbitrary code -- only load trusted checkpoints.
with open(f'./ckpts/{d}/bpe_iter=0.pkl', 'rb') as ckpt_file:
    bpe = pickle.load(ckpt_file)

In [None]:
# Number of structures that survived the missing-dihedral filter.
# (Commented lines: tokenizer counts of the loaded checkpoint, `ref`, and a later checkpoint.)
# len(bpe.tokenizers), len(ref.tokenizers)
# len(pickle.load(open("./ckpts/1751936564.1540673/bpe_iter=100.pkl", "rb")).tokenizers)
len(cleaned_structures)

In [None]:
# Select the BPE iteration to inspect and load its checkpoint (file closed afterwards).
# NOTE: pickle.load executes arbitrary code -- only load trusted checkpoints.
t = 0
path = f'./ckpts/{d}/bpe_iter={t}.pkl'
with open(path, 'rb') as ckpt_file:
    bpe = pickle.load(ckpt_file)


In [None]:
# For every tokenizer, print the bonds whose bond_to_token entry holds a tuple
# in position 1 (as reported by `modified`).
for index, t in enumerate(bpe.tokenizers):
    for k in modified(t):
        print(index, k, t.bond_to_token[k])

# Focus the rest of the walkthrough on a single structure.
index = 6
t = bpe.tokenizers[index]

In [None]:
# Bond window passed to token_geo / visualize_bonds below (start, length).
start = 69
length = 6
# Key into bpe._tokens for the vocabulary entry being inspected.
occur = (30, 8)

In [None]:
# Absolute output paths: full renders and bond-window renders for the
# checkpoint tokenizer `t` and the matching reference tokenizer.
path, ref_path, bond_path, ref_bond_path = (
    os.path.abspath(f'../{stem}.png')
    for stem in ('test', 'ref', 'test_bonds', 'ref_bonds')
)
t.visualize(path)
ref.tokenizers[index].visualize(ref_path)
t.visualize_bonds(start, length, bond_path)
ref.tokenizers[index].visualize_bonds(start, length, ref_bond_path)
vis_images(ref_bond_path, bond_path)
# vis_images(*([bond_path] + [f'./ckpts/{d}/key_iter=0_{i}.png' for i in range(10)]))

In [None]:
# Full-structure renders side by side: reference tokenizer vs. checkpoint tokenizer `t`.
vis_images(ref_path, path)

In [None]:
# Show `t`'s token_geo for the inspected bond window next to the vocabulary
# entry bpe._tokens[occur] (presumably the motif covering that window -- verify).
t.token_geo(start, length), bpe._tokens[occur]

In [None]:
# Check that the reference and checkpoint tokenizers at `index` refer to the same file.
ref.tokenizers[index].fname, t.fname

In [None]:
# Round-trip check: the geometry recovered from `t`'s token sequence must equal
# the geometry read directly from `t`.
# NOTE(review): assumes token_geo(0, 3*t.n-1) spans every bond of the structure -- confirm.
full = t.token_geo(0, 3*t.n-1)
tokenized = t.tokenize()
repl = bpe.recover(tokenized)
assert full == repl
# NOTE(review): quantize's return value is discarded; presumably it updates
# `tokenized` in place before it is displayed below -- confirm.
bpe.quantize(tokenized)
tokenized

In [None]:
# Angles table of the first cleaned structure (rebinds the global `struc` left by the filtering loop).
struc = cleaned_structures[0]['angles']

In [None]:
# `n` of the first checkpoint tokenizer (presumably its residue count).
bpe.tokenizers[0].n

In [None]:
# `ref` was built from `cleaned_structures`; presumably these counts match (one tokenizer per structure).
len(cleaned_structures), len(ref.tokenizers)

In [None]:
# Internal `_angles_and_dists` of the first reference tokenizer.
ref.tokenizers[0]._angles_and_dists

In [None]:
# Raw angles table of the first cleaned structure, for comparison with the tokenizer's stored values.
cleaned_structures[0]['angles']

In [None]:
# `n` of the first reference tokenizer, for comparison with bpe.tokenizers[0].n.
ref.tokenizers[0].n

In [None]:
# cleaned_structures[0]
# The "0C:1N" entry of the tokenizer's public angles_and_dists view.
t.angles_and_dists["0C:1N"]

In [None]:
# Private `_angles_and_dists` of `t`, for comparison with the public `angles_and_dists`.
t._angles_and_dists

In [None]:
import pandas as pd
from foldingdiff.tokenizer import *
def init_structure(n):
    """Build an empty structure record with `n` rows for `Tokenizer`.

    The "0C:1N", "N:CA" and "CA:C" columns start at 0.0, and the angle columns
    ("phi", "psi", "omega", "tau", "CA:C:1N", "C:1N:1CA") start as NaN.
    `recover_structure` fills them in afterwards.

    Args:
        n: number of rows in the angles table.

    Returns:
        dict with an `angles` DataFrame of shape (n, 9) (columns in the order
        listed above), `full_idxes` listing each index 1..n three times in a
        row, and every other field set to None.
    """
    zero_cols = ("0C:1N", "N:CA", "CA:C")
    nan_cols = ("phi", "psi", "omega", "tau", "CA:C:1N", "C:1N:1CA")
    # Insertion order fixes the DataFrame column order.
    angles = {col: [0.0] * n for col in zero_cols}
    angles.update({col: [np.nan] * n for col in nan_cols})
    # Flat comprehension instead of sum(list_of_lists, []), which is O(n^2).
    idxes = [i for i in range(1, n + 1) for _ in range(3)]
    return {
        "angles": pd.DataFrame(angles),
        "coords": None,
        "c_beta": None,
        "full_idxes": idxes,
        "full_coords": None,
        "side_chain": None,
        "aa": None,
        "fname": None
    }

def recover_structure(repl, tokens=None, bpe_model=None):
    """Rebuild a `Tokenizer` from recovered geometry plus a token sequence.

    The columns of `repl` (e.g. the output of `bpe.recover(tokenized)`) are
    written into a fresh `init_structure(n)` table. `bond_to_token` is then
    rebuilt from the ("MOTIF", token_id, ...) entries of the token sequence.

    Args:
        repl: mapping from column name to a sequence. n = len(repl["N:CA"]).
            "N:CA", "CA:C" and "tau" must have n entries (the first is
            dropped). "phi" fills rows 1..n-1. Every other column fills rows
            0..n-2, so all columns other than those three must have n - 1
            entries.
        tokens: token sequence of `(key, *args)` entries; only "MOTIF" entries
            are used. Defaults to the notebook-global `tokenized`.
        bpe_model: object whose `_tokens[token_id]` is passed to
            `Tokenizer.num_bonds`. Defaults to the notebook-global `bpe`.

    Returns:
        A new `Tokenizer` whose `bond_to_token` maps each motif's starting bond
        index `cur` to `(cur, token_id, num_bonds)`.
    """
    # Fall back to notebook globals so existing one-argument calls keep working.
    if tokens is None:
        tokens = tokenized
    if bpe_model is None:
        bpe_model = bpe

    n = len(repl["N:CA"])
    struc = init_structure(n)
    angles = struc["angles"]

    head = slice(None, -1)   # all rows but the last
    tail = slice(1, None)    # all rows but the first
    fills = [
        ("N:CA", head, repl["N:CA"][1:]),
        ("CA:C", head, repl["CA:C"][1:]),
        ("0C:1N", head, repl["0C:1N"]),
        ("phi", tail, repl["phi"]),
        ("psi", head, repl["psi"]),
        ("omega", head, repl["omega"]),
        ("tau", head, repl["tau"][1:]),
        ("CA:C:1N", head, repl["CA:C:1N"]),
        ("C:1N:1CA", head, repl["C:1N:1CA"]),
    ]
    # Write through the DataFrame itself. Chained assignment
    # (angles[col].iloc[...] = ...) silently does nothing under pandas
    # Copy-on-Write. list() drops any index so values are placed positionally.
    for col, rows, values in fills:
        angles.iloc[rows, angles.columns.get_loc(col)] = list(values)

    t_new = Tokenizer(struc)
    t_new.bond_to_token = {}
    cur = 0
    for key, *pargs in tokens:
        if key == "MOTIF":
            token_id = pargs[0]
            nb = Tokenizer.num_bonds(bpe_model._tokens[token_id])
            t_new.bond_to_token[cur] = (cur, token_id, nb)
            cur += nb
    return t_new

In [None]:
# Token sequence produced by t.tokenize() (and passed to bpe.quantize) in the round-trip cell.
tokenized