In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Image, display
import matplotlib.image as mpimg
from tqdm import tqdm
import math
import os
import json
from types import SimpleNamespace
from foldingdiff.tokenizer import Tokenizer
from collections import defaultdict
from foldingdiff.datasets import *
from foldingdiff.utils import *
from foldingdiff.plotting import plot
os.chdir(Path.cwd().parents[0])
from bin.encode import BPE

In [None]:
def modified(t):
    """Return the keys of ``t.bond_to_token`` whose value holds a tuple at index 1.

    Keys are returned in the dictionary's iteration order.
    """
    return [
        bond
        for bond, entry in t.bond_to_token.items()
        if isinstance(entry[1], tuple)
    ]


def compare(t1, t2):
    """Return ``compute_rmsd`` applied to the coordinates of ``t1`` and ``t2``.

    Both arguments must expose a ``compute_coords()`` method; ``t1`` is
    evaluated first, then ``t2``.
    """
    coords_first = t1.compute_coords()
    coords_second = t2.compute_coords()
    return compute_rmsd(coords_first, coords_second)


def vis_images(*paths, scale=4):
    """
    Display an arbitrary number of images in a square-ish grid layout.

    Parameters:
    *paths: variable number of image file paths (``str`` or ``os.PathLike``,
        e.g. ``pathlib.Path``)
    scale: edge length given to each grid cell; the figure is created with
        ``figsize=(n_cols * scale, n_rows * scale)``

    Returns None. Prints a message and does nothing else when no paths are given.
    """
    n = len(paths)
    if n == 0:
        print("No images to display.")
        return

    # Determine grid size (close to square)
    n_cols = math.ceil(math.sqrt(n))
    n_rows = math.ceil(n / n_cols)

    # squeeze=False always returns a 2-D array of axes, so the single-image
    # case needs no special handling before flattening.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(n_cols * scale, n_rows * scale), squeeze=False
    )
    axes = axes.flatten()

    # Display each image, titled with its file name
    for ax, path in zip(axes, paths):
        # os.fspath lets callers pass pathlib.Path objects as well as strings
        path = os.fspath(path)
        ax.imshow(mpimg.imread(path))
        ax.set_title(os.path.basename(path))
        ax.axis('off')

    # Hide any unused subplots
    for ax in axes[n:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Preview one saved run plot; with a single image, scale=10 gives figsize=(10, 10).
# NOTE(review): this run directory is hard-coded and differs from `d` set in the next cell.
vis_images('ckpts/1752521366.505088/run_iter=50.png', scale=10)

In [None]:
# Name of the run directory under ./ckpts that the rest of the notebook inspects.
# Swap in one of the commented alternatives to look at a different run.
d = "1752566951.7877476"
# d = "1752336941.5962725"
# d = "1752336941.5956001"
# d = "1752293573.2959695"

# Load the arguments saved alongside that run and show them as a dict.
args_path = f"./ckpts/{d}/args.txt"
args = load_args_from_txt(args_path)
vars(args)

In [None]:
import re

# Run directories to scan; each is expected to hold run_iter=<N>.png files and ref_coords.npy.
# NOTE(review): step_iter below comes from `args`, which was loaded for run `d` only —
# confirm it applies to every folder listed here.
# folders = ["./ckpts/1752293573.2959695",
#            "./ckpts/1752293573.2966762",
#            "./ckpts/1752336941.5962725",
#            "./ckpts/1752336941.5956001"]
folders = ["./ckpts/1752566951.7877476"]


def iter_num(fp: Path):
    """Return the integer N from a ``run_iter=N.png`` file name, or -1 if the name does not match."""
    m = re.search(r"run_iter=(\d+)\.png$", fp.name)
    return int(m.group(1)) if m else -1


paths = []
for folder in folders:
    p = Path(folder)
    # find all run_iter PNGs, ignoring names whose suffix is not an integer
    # (the glob alone would also accept e.g. run_iter=final.png)
    pngs = [fp for fp in p.glob("run_iter=*.png") if iter_num(fp) >= 0]
    if not pngs:
        print(f"No run_iter PNGs in {folder!r}, skipping.")
        continue

    # pick the file with the max iteration
    latest = max(pngs, key=iter_num)
    latest_iter = iter_num(latest)
    # allow_pickle=True can execute arbitrary code; only load arrays written by your own runs.
    ref_coords = np.load(os.path.join(folder, "ref_coords.npy"), allow_pickle=True)
    # Build the file name from the iteration number: `latest` is a full Path, so
    # interpolating it directly would produce "run_iter=<folder>/run_iter=N.png.png".
    run_path = os.path.join(folder, f"run_iter={latest_iter}.png")
    plot(ref_coords, p.name, run_path, no_iters=latest_iter, step_iter=args.save_every, ratio=None)
    print(f"Latest in {folder!r} → {latest}")
    paths.append(str(latest))

# now display them
vis_images(*paths)

In [None]:
# Build the dataset from the arguments saved with the run.
dataset = FullCathCanonicalCoordsDataset(
    args.data_dir,
    use_cache=False,
    debug=False,
    zero_center=False,
    toy=args.toy,
    pad=args.pad,
    secondary=args.sec,
)

# Drop structures with too many missing psi values: a structure is kept only
# when at most one psi entry fails the self-equality test.
cleaned_structures = []
for i, struc in enumerate(dataset.structures):
    psi = struc['angles']['psi']
    # NaN never equals itself, so (psi == psi) counts the non-NaN entries
    # (presumably missing dihedrals are stored as NaN — TODO confirm).
    n_present = (psi == psi).sum()
    if n_present < len(psi) - 1:
        print(f"skipping {i}, {struc['fname']} because of missing dihedrals")
        continue
    cleaned_structures.append(struc)
dataset.structures = cleaned_structures

# Reference BPE built from the cleaned structures with the run's settings.
ref = BPE(
    dataset.structures,
    bins=args.bins,
    bin_strategy=args.bin_strategy,
    save_dir=f'./ckpts/{d}',
    rmsd_partition_min_size=args.p_min_size,
    num_partitions=args.num_p,
    compute_sec_structs=args.sec,
    plot_iou_with_sec_structs=args.sec_eval,
    res_init=args.res_init,
)
ref.initialize()

In [None]:
# len(bpe.tokenizers), len(ref.tokenizers)
# len(pickle.load(open("./ckpts/1751936564.1540673/bpe_iter=100.pkl", "rb")).tokenizers)
# Number of structures kept after the missing-dihedral filter.
len(cleaned_structures)

In [None]:
# Iteration of the saved BPE checkpoint to load from run `d`.
t = 0
path = f'./ckpts/{d}/bpe_iter={t}.pkl'
# NOTE: unpickling can execute arbitrary code; only load checkpoints you produced yourself.
# The context manager closes the file handle once the checkpoint has been read.
with open(path, 'rb') as f:
    bpe = pickle.load(f)


In [None]:
# Print every bond reported by `modified`, tagged with its tokenizer's position,
# together with the bond's entry in bond_to_token.
for index, t in enumerate(bpe.tokenizers):
    for k in modified(t):
        print(index, k, t.bond_to_token[k])

# Select one tokenizer for the detailed inspection in the following cells.
index = 6
t = bpe.tokenizers[index]

In [None]:
# Span handed to visualize_bonds(...) and token_geo(...) in the cells below.
start = 69
length = 6
# Key used to look up an entry in bpe._tokens below.
occur = (30, 8)

In [None]:
# Output files for the renderings: "test" is the tokenizer taken from the loaded
# checkpoint (`t`), "ref" is the tokenizer at the same index in the reference BPE.
path = os.path.abspath('../test.png')
ref_path = os.path.abspath('../ref.png')
bond_path = os.path.abspath('../test_bonds.png')
ref_bond_path = os.path.abspath('../ref_bonds.png')

# Full-structure renderings (displayed by the next cell).
t.visualize(path)
ref.tokenizers[index].visualize(ref_path)

# Renderings restricted to the span given by (start, length), shown side by side.
t.visualize_bonds(start, length, bond_path)
ref.tokenizers[index].visualize_bonds(start, length, ref_bond_path)
vis_images(ref_bond_path, bond_path)
# vis_images(*([bond_path] + [f'./ckpts/{d}/key_iter=0_{i}.png' for i in range(10)]))

In [None]:
# Show the reference rendering next to the test rendering.
vis_images(ref_path, path)

In [None]:
# Display the token_geo result for the selected span next to the bpe._tokens entry keyed by `occur`.
# NOTE(review): presumably the two are expected to correspond — confirm.
t.token_geo(start, length), bpe._tokens[occur]

In [None]:
# Display the fname of the reference tokenizer at `index` next to the fname of `t`.
# NOTE(review): presumably these should match if both describe the same structure — confirm.
ref.tokenizers[index].fname, t.fname

In [None]:
# Round-trip check: tokenizing `t` and recovering through `bpe` must reproduce
# the token_geo result taken before tokenizing.
# NOTE(review): (0, 3*t.n-1) presumably spans the whole structure — confirm.
full = t.token_geo(0, 3*t.n-1)
tokenized = t.tokenize()
repl = bpe.recover(tokenized)
assert full == repl
# The return value is discarded; presumably quantize modifies `tokenized` in place — TODO confirm.
bpe.quantize(tokenized)
tokenized

In [None]:
# Keep the 'angles' entry of the first retained structure (rebinds `struc`, the earlier loop variable).
struc = cleaned_structures[0]['angles']

In [None]:
# Display tokenizer 0's `n` next to the lengths of the recovered "0C:1N" and "N:CA" entries.
# NOTE(review): `repl` was recovered from `t` (the tokenizer selected via `index`), not from
# tokenizer 0 — confirm this comparison is intended.
bpe.tokenizers[0].n, len(repl["0C:1N"]), len(repl["N:CA"])

In [None]:
# Display the number of retained structures next to the number of reference tokenizers.
# NOTE(review): presumably one tokenizer per structure — confirm.
len(cleaned_structures), len(ref.tokenizers)

In [None]:
# Display the private _angles_and_dists attribute of the first reference tokenizer.
ref.tokenizers[0]._angles_and_dists

In [None]:
# Display the 'angles' entry of the first retained structure.
cleaned_structures[0]['angles']

In [None]:
# Display `n` of the first reference tokenizer.
ref.tokenizers[0].n

In [None]:
# cleaned_structures[0]
# Display the "0C:1N" entry of the selected tokenizer's angles_and_dists.
t.angles_and_dists["0C:1N"]

In [None]:
# Display the private _angles_and_dists attribute of the selected tokenizer.
t._angles_and_dists

In [None]:
# Display `tokenized` again, as left by the round-trip cell (after the bpe.quantize call).
tokenized