In [None]:
# Setup cell: imports, working-directory switch, and CUDA memory-stat reset.
import os
# NOTE(review): the sklearn metrics and Counter are not used in the cells
# visible here; they may be needed by cells outside this view.
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, adjusted_rand_score, mutual_info_score
from collections import Counter
# Change into the repository checkout BEFORE the project imports below;
# presumably `bin` and `foldingdiff` are resolved relative to this directory
# -- TODO confirm. NOTE(review): hard-coded absolute, user-specific path.
os.chdir('/n/holylfs06/LABS/mzitnik_lab/Users/msun415/foldingdiff')
# NOTE(review): the star imports below hide where names such as `get_model`,
# `Path`, `glob` or `pickle` used in later cells come from -- presumably they
# are re-exported by these modules; verify.
from bin.learn import *
import torch
import json
from types import SimpleNamespace
from foldingdiff.potential_model import *
from foldingdiff.modelling import *
from foldingdiff.tokenizer import Tokenizer
from foldingdiff.datasets import *
# NOTE(review): torch is already imported above; this second import is redundant.
import torch
# Reset the CUDA allocator's peak-memory counters (only when a GPU is present)
# so that peak-memory readings taken later start from this point.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

In [None]:
# d = '1749783864.896021'
# bad_key = '3lca_A'
# fnames = glob(os.path.join(REPEAT_DIR, "*.pdb"))
# j = next((i for i, v in enumerate(fnames) if bad_key in v), None)
# i = j//100
# path = f"/n/netscratch/mzitnik_lab/Lab/msun415/{d}/feats_100_{i}.pkl"
# print(f"loading {i}")
# stuff = pickle.load(open(path, "rb"))
# # if bad_key in stuff:
# #     print(f"dumping {bad_key}")
# #     pickle.dump(stuff[bad_key], open(f'/n/holylfs06/LABS/mzitnik_lab/Users/msun415/foldingdiff/ckpts/{d}/{bad_key}.pkl', 'wb+'))
# stuff.keys()

In [None]:
# ---------------- load the saved run arguments --------------------
# NOTE(review): hard-coded absolute checkpoint directory for one specific run.
ckpt_dir = "/n/holylfs06/LABS/mzitnik_lab/Users/msun415/foldingdiff/ckpts/1749783864.878101"
args_path = os.path.join(ckpt_dir, "args.json")
with open(args_path, "r") as f:
    arg_dict = json.load(f)
# Expose the JSON keys as attributes (args.data_dir, args.cuda, ...).
args = SimpleNamespace(**arg_dict)

# ---------------- build the tokenized dataset ---------------------
raw_ds = FullCathCanonicalCoordsDataset(
    args.data_dir, use_cache=False, debug=args.debug,
    zero_center=False, toy=30, pad=args.pad, secondary=False,
    trim_strategy="discard"
)
dataset = [Tokenizer(x) for x in raw_ds.structures]
# maximum token length for Transformer positional encodings
max_len = max(3 * (3 * t.n - 1) - 2 for t in dataset)
device = args.cuda

# ---------------- build model ------------------------------------
model = get_model(args, device, max_len=max_len)           # returns SemiCRFModel
model.to(args.cuda)
# NOTE(review): the model's train/eval mode is not changed here; if it has
# dropout or batch-norm layers, consider calling model.eval() before
# inference -- confirm against the model definition.
if args.config:
    # Context manager so the config file handle is closed deterministically.
    with open(args.config) as config_file:
        config = json.load(config_file)

# compute feats in batches
# Rebinds `dataset`: from here on it is the FeatDataset wrapper, not the list.
dataset = FeatDataset(dataset, args.save_dir)

# ---------------- restore the latest checkpoint -------------------
checkpoint_path, epoch = find_latest_checkpoint(args.save_dir)
if not checkpoint_path:
    # Fail with a clear message instead of an opaque error inside torch.load.
    raise FileNotFoundError(f"no checkpoint found in {args.save_dir}")
print(checkpoint_path)
ckpt = torch.load(checkpoint_path, map_location=model.device)
# Checkpoints may either wrap the weights under 'model_state' or be a bare state dict.
if 'model_state' in ckpt:
    model.load_state_dict(ckpt['model_state'])
else:
    model.load_state_dict(ckpt)

In [None]:
# For each selected structure: run the semi-CRF MAP segmentation, plot it,
# then refine the segmentation into a hierarchy by greedy splitting ("down")
# followed by greedy merging of adjacent segments ("up"), plotting each stage.
for idx in [25]:
    (_, t, feats) = dataset[idx]
    N = t.n
    assert t.n == len(t.aa), "number of residues != length of amino acid sequence"
    coords = t.compute_coords()
    out, attn_scores = model.precompute(
        feats         = feats,
        aa_seq        = t.aa,              # Tokenizer stores AA sequence
        coords_tensor = coords
    )                                       # out[i][l] ready for DP
    log_a, map_a, best_lens = semi_crf_dp_and_map(out, N, gamma=args.gamma)
    best_seg = backtrace_map_segmentation(best_lens, N)
    # Average the attention scores of the segments chosen by the MAP segmentation.
    attn_stack = torch.stack([attn_scores[start][end-start] for start, end in best_seg], axis=0)
    attn_agg = attn_stack.mean(axis=0)
    # One entry per MAP segment, keyed by 3*start; each value is
    # (3*start, 3*seg_idx, length), with the length clipped to 3*t.n-1-3*start.
    # Presumably (start bond, token id, bond count) -- confirm against Tokenizer.
    t.bond_to_token = {3*start: (3*start, 3*seg_idx, min(3*(end-start), 3*t.n-1-3*start))
                    for seg_idx, (start, end) in enumerate(best_seg)}
    loss   = -log_a[N]                       # negative log-partition
    # Probability of the MAP segmentation: exp(MAP score - log-partition).
    prob = torch.exp(map_a[N] - log_a[N]).item()
    print(idx, prob)
    epoch = -1
    # Shared file-name stem for every figure written for this structure.
    plot_stem = f"epoch={epoch}_idx={idx}_p={prob:.3f}"
    path = Path(os.path.join(args.plot_dir, f"{plot_stem}.png"))
    attn_path = path.with_name(path.stem + "_attn" + path.suffix)
    t.visualize(path, vis_dihedral=False)
    plot_feature_importance(attn_agg.detach().cpu().numpy(), model.aggregator.per_res_labels, attn_path)
    t.bond_to_token.tree.visualize(os.path.join(args.plot_dir, f"{plot_stem}_tree.png"), horizontal_gap=1.0, font_size=6)

    # # start building hierarchy (down)
    # Greedily split the leaf/offset pair with the highest combined score,
    # for at most 20 rounds.
    for _ in range(20):
        vals = [leaf.value for leaf in t.bond_to_token.tree.leaves.values()]
        max_bt = max(val[1] for val in vals)
        best_split, best = None, float("-inf")
        for i, (i1, _, l1) in enumerate(vals):
            # Candidate split offsets: multiples of 3 in [1, l1-2].
            for j in range(3, l1 - 1, 3):
                assert (l1-j+1)%3 != 2
                expr = out[i1//3][j//3] + out[(i1+j)//3][(l1-j+1)//3]
                if expr > best:
                    best = expr
                    best_split = (i, j)
        if best_split is None:
            # No leaf can be split any further; stop instead of silently
            # splitting vals[-1] at a stale offset.
            break
        leaf_i, split_j = best_split
        (i1, b1, l1) = vals[leaf_i]
        # Split at the BEST offset found above, not at whatever value the loop
        # variable `j` happened to hold when the search finished.
        t.bond_to_token.tree.split((i1, b1, l1), (i1, max_bt+1, split_j), (i1+split_j, max_bt+2, l1-split_j))
        max_bt += 2

    # # start building hierarchy (up)
    t.bond_to_token.tree.visualize(os.path.join(args.plot_dir, f"{plot_stem}_down.png"), horizontal_gap=0.5, font_size=6)
    # Greedily merge the adjacent pair of segments with the highest score,
    # for at most 20 rounds.
    for _ in range(20):
        vals = list(t.bond_to_token.values())
        max_bt = max(val[1] for val in vals)
        best_i, best = -1, float("-inf")
        for i, (i1, _, l1) in enumerate(vals):
            if i < len(t.bond_to_token)-1:
                (i2, _, l2) = vals[i+1]
                # Consecutive entries must be contiguous for a merge to make sense.
                assert i1+l1 == i2
                try:
                    merged_score = out[i1//3][(l1+l2)//3]
                    if merged_score > best:
                        best = merged_score
                        best_i = i
                except Exception:
                    # Print the offending indices for debugging, then re-raise.
                    print(i1, l1, l2)
                    raise
        if best_i < 0:
            break
        (i1, _, l1) = vals[best_i]
        (i2, _, l2) = vals[best_i+1]
        # Replace the two neighbours by one merged entry carrying a fresh id.
        t.bond_to_token.pop(i2)
        t.bond_to_token[i1] = (i1, max_bt+1, l1+l2)
        max_bt += 1

    t.bond_to_token.tree.visualize(os.path.join(args.plot_dir, f"{plot_stem}_up.png"), horizontal_gap=0.5, font_size=6)

    # epoch = -1
    path = Path(os.path.join(args.plot_dir, f"{plot_stem}_after.png"))
    attn_path = path.with_name(path.stem + "_attn" + path.suffix)
    t.visualize(path, vis_dihedral=False)
    plot_feature_importance(attn_agg.detach().cpu().numpy(), model.aggregator.per_res_labels, attn_path)