In [8]:
from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable
import os
import string
from pathlib import Path

from scipy.spatial.distance import squareform, pdist, cdist
import biotite.structure as bs
from biotite.structure.io.pdbx import PDBxFile, get_structure
from biotite.database import rcsb

import torch
import torch.nn as nn
import numpy as np 
import pandas as pd
import matplotlib as mpl

from tqdm import tqdm
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.Alphabet import SingleLetterAlphabet

import copy

import sys
sys.path.append('..')
from linear_quant import *

import esm

# Turn autograd off globally for this kernel session; the cells visible in this
# notebook only run forward passes, so no gradients are needed.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fc18329d760>

In [9]:
# Translation table that strips every lowercase ASCII letter plus the "." and "*"
# characters from a string in a single pass (used below to drop insertions).
deletekeys = {char: None for char in string.ascii_lowercase + ".*"}
translation = str.maketrans(deletekeys)

def read_sequence(filename: str) -> Tuple[str, str]:
    """Read the first (reference) sequence from a FASTA or MSA file.

    Returns:
        A ``(description, sequence)`` pair for the first record in the file.
    """
    records = SeqIO.parse(filename, "fasta")
    first_record = next(records)
    description = first_record.description
    sequence = str(first_record.seq)
    return description, sequence

def remove_insertions(sequence: str) -> str:
    """Strip insertions from an aligned sequence.

    Applies the module-level ``translation`` table, which deletes lowercase
    letters as well as the "." and "*" characters. Needed when loading
    aligned sequences from an MSA.
    """
    without_insertions = str.translate(sequence, translation)
    return without_insertions

def read_msa(filename: str) -> List[Tuple[str, str]]:
    """Read every sequence of an MSA file, with insertions removed.

    Returns:
        One ``(description, sequence)`` pair per FASTA record, in file order.
    """
    entries = []
    for record in SeqIO.parse(filename, "fasta"):
        aligned = remove_insertions(str(record.seq))
        entries.append((record.description, aligned))
    return entries

In [10]:
def extend(a, b, c, L, A, D):
    """
    Place a fourth point relative to three reference points.

    input:  3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral
    output: 4th coord

    The angle and dihedral are passed straight to np.cos / np.sin, i.e. they
    are interpreted as radians. The returned point lies at distance L from c.
    """

    def _unit(vec):
        # Scale to unit length along the last axis (works for batched input).
        return vec / np.linalg.norm(vec, ord=2, axis=-1, keepdims=True)

    # Local orthonormal frame built from the three reference points.
    bc = _unit(b - c)
    normal = _unit(np.cross(b - a, bc))
    frame = (bc, np.cross(normal, bc), normal)

    # Coordinates of the new point expressed in that frame.
    sin_a = np.sin(A)
    weights = (L * np.cos(A), L * sin_a * np.cos(D), -L * sin_a * np.sin(D))

    offset = 0
    for axis, weight in zip(frame, weights):
        offset = offset + axis * weight
    return c + offset


def contacts_from_pdb(
    structure: bs.AtomArray,
    distance_threshold: float = 8.0,
    chain: Optional[str] = None,
) -> np.ndarray:
    """Build a contact map from a structure's backbone atoms.

    Hetero atoms are ignored and, when ``chain`` is given, only atoms of that
    chain are used. A C-beta position is computed for each residue from the
    N, CA and C coordinates via ``extend`` and the pairwise distances between
    those positions are thresholded.

    Returns:
        An int64 square matrix holding 1 where the distance is below
        ``distance_threshold``, 0 otherwise, and -1 where the distance is NaN.
    """
    keep = ~structure.hetero
    if chain is not None:
        keep &= structure.chain_id == chain

    def _backbone(atom_name):
        # Coordinates of all kept atoms carrying the given atom name.
        return structure.coord[keep & (structure.atom_name == atom_name)]

    N = _backbone("N")
    CA = _backbone("CA")
    C = _backbone("C")

    Cbeta = extend(C, N, CA, 1.522, 1.927, -2.143)
    dist = squareform(pdist(Cbeta))

    contacts = (dist < distance_threshold).astype(np.int64)
    missing = np.isnan(dist)
    contacts[missing] = -1
    return contacts

In [11]:
# Select sequences from the MSA to maximize the hamming distance
# Alternatively, can use hhfilter 
def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]:
    """Greedily pick ``num_seqs`` sequences from an MSA by mean Hamming distance.

    The first sequence is always kept. Each further pick is the not-yet-chosen
    sequence whose mean Hamming distance to the already chosen ones is largest
    (``mode="max"``) or smallest (``mode="min"``). The result keeps the
    original MSA order. If the MSA has at most ``num_seqs`` entries it is
    returned unchanged.
    """
    assert mode in ("max", "min")
    if len(msa) <= num_seqs:
        return msa

    # One row per sequence, one uint8 code per alignment column.
    encoded = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)

    if mode == "max":
        pick = np.argmax
    else:
        pick = np.argmin

    candidates = np.arange(len(msa))
    chosen = [0]
    dists_to_chosen = np.zeros((0, len(msa)))
    picks_left = num_seqs - 1
    while picks_left > 0:
        # Distances from the most recently chosen sequence to every sequence.
        newest = encoded[chosen[-1:]]
        dists_to_chosen = np.concatenate([dists_to_chosen, cdist(newest, encoded, "hamming")])
        # Mean distance to the chosen set, restricted to unchosen columns.
        mean_dist = np.delete(dists_to_chosen, chosen, axis=1).mean(0)
        best = pick(mean_dist)
        chosen.append(np.delete(candidates, chosen)[best])
        picks_left -= 1

    return [msa[idx] for idx in sorted(chosen)]

In [12]:
def compute_precisions(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    src_lengths: Optional[torch.Tensor] = None,
    minsep: int = 6,
    maxsep: Optional[int] = None,
    override_length: Optional[int] = None,  # for casp
):
    """Compute top-k contact precision at ten cutoffs (0.1*L ... 1.0*L).

    Args:
        predictions: Contact scores of shape (L, L) or (batch, L, L); higher
            scores rank first. numpy arrays are converted to tensors.
        targets: Contact labels with the same shape as ``predictions``.
            Negative entries are treated as invalid and never selected.
        src_lengths: Optional per-example lengths of shape (batch,). Pairs
            involving a position at or beyond the length are excluded.
        minsep: Minimum sequence separation ``j - i`` of a scored pair.
        maxsep: Optional exclusive upper bound on the separation.
        override_length: Length L used to size the cutoffs. When omitted it
            defaults to the number of non-negative entries in the first row
            of the first target map.

    Returns:
        Dict with keys "AUC", "P@L", "P@L2" and "P@L5", each a tensor of shape
        (batch,). "P@L", "P@L2" and "P@L5" are the precisions of the top
        L, L/2 and L/5 predictions; "AUC" is the mean over all ten cutoffs.

    Raises:
        ValueError: If predictions and targets differ in size.
    """
    if isinstance(predictions, np.ndarray):
        predictions = torch.from_numpy(predictions)
    if isinstance(targets, np.ndarray):
        targets = torch.from_numpy(targets)
    if predictions.dim() == 2:
        predictions = predictions.unsqueeze(0)
    if targets.dim() == 2:
        targets = targets.unsqueeze(0)
    # Only derive the length from the targets when the caller did not supply
    # one (previously the argument was overwritten unconditionally).
    if override_length is None:
        override_length = (targets[0, 0] >= 0).sum()
    override_length = int(override_length)

    # Check sizes
    if predictions.size() != targets.size():
        raise ValueError(
            f"Size mismatch. Received predictions of size {predictions.size()}, "
            f"targets of size {targets.size()}"
        )
    device = predictions.device

    batch_size, seqlen, _ = predictions.size()
    seqlen_range = torch.arange(seqlen, device=device)

    # sep[i, j] = j - i, so only the upper triangle can satisfy sep >= minsep.
    sep = seqlen_range.unsqueeze(0) - seqlen_range.unsqueeze(1)
    sep = sep.unsqueeze(0)
    valid_mask = sep >= minsep
    valid_mask = valid_mask & (targets >= 0)  # negative targets are invalid

    if maxsep is not None:
        valid_mask &= sep < maxsep

    if src_lengths is not None:
        valid = seqlen_range.unsqueeze(0) < src_lengths.unsqueeze(1)
        valid_mask &= valid.unsqueeze(1) & valid.unsqueeze(2)

    # Invalid pairs get -inf so they sort behind every valid pair.
    predictions = predictions.masked_fill(~valid_mask, float("-inf"))

    x_ind, y_ind = np.triu_indices(seqlen, minsep)
    predictions_upper = predictions[:, x_ind, y_ind]
    targets_upper = targets[:, x_ind, y_ind]

    topk = max(seqlen, override_length)
    indices = predictions_upper.argsort(dim=-1, descending=True)[:, :topk]
    batch_index = torch.arange(batch_size, device=device).unsqueeze(1)
    topk_targets = targets_upper[batch_index, indices]
    if topk_targets.size(1) < topk:
        # Fewer candidate pairs than top-k: pad with zeros (non-contacts).
        # The original referenced an un-imported `F` here and raised NameError.
        topk_targets = torch.nn.functional.pad(
            topk_targets, [0, topk - topk_targets.size(1)]
        )

    cumulative_dist = topk_targets.type_as(predictions).cumsum(-1)

    # Every example is binned with the same length; src_lengths only affects
    # which pairs are valid.
    gather_lengths = torch.full(
        (batch_size, 1), override_length, dtype=torch.long, device=device
    )

    # Zero-based index of the last prediction inside each of the ten cutoffs.
    gather_indices = (
        torch.arange(0.1, 1.1, 0.1, device=device).unsqueeze(0) * gather_lengths
    ).type(torch.long) - 1

    binned_cumulative_dist = cumulative_dist.gather(1, gather_indices)
    binned_precisions = binned_cumulative_dist / (gather_indices + 1).type_as(
        binned_cumulative_dist
    )

    pl5 = binned_precisions[:, 1]
    pl2 = binned_precisions[:, 4]
    pl = binned_precisions[:, 9]
    auc = binned_precisions.mean(-1)

    return {"AUC": auc, "P@L": pl, "P@L2": pl2, "P@L5": pl5}


def evaluate_prediction(
    predictions: torch.Tensor,
    targets: torch.Tensor,
) -> Dict[str, float]:
    """Score a contact prediction separately for four separation ranges.

    ``compute_precisions`` is run once per range ("local", "short", "medium",
    "long") and each of its metrics is stored as a Python float under the key
    ``"<range>_<metric>"``, e.g. ``"long_P@L"``.

    Args:
        predictions: Predicted contact scores (tensor).
        targets: Ground-truth contacts; a numpy array is converted to a
            tensor and moved to the device of ``predictions``.
    """
    if isinstance(targets, np.ndarray):
        targets = torch.from_numpy(targets)
    targets = targets.to(predictions.device)

    # (range name, minimum separation, exclusive maximum separation)
    separation_ranges = (
        ("local", 3, 6),
        ("short", 6, 12),
        ("medium", 12, 24),
        ("long", 24, None),
    )

    metrics = {}
    for range_name, lower, upper in separation_ranges:
        range_scores = compute_precisions(
            predictions, targets, minsep=lower, maxsep=upper
        )
        metrics.update(
            {f"{range_name}_{key}": val.item() for key, val in range_scores.items()}
        )
    return metrics

In [13]:
"""Adapted from: https://github.com/rmrao/evo/blob/main/evo/visualize.py"""
def plot_contacts_and_predictions(
    predictions: Union[torch.Tensor, np.ndarray],
    contacts: Union[torch.Tensor, np.ndarray],
    ax: Optional[mpl.axes.Axes] = None,
    cmap: str = "Blues",
    ms: float = 1,
    title: Union[bool, str, Callable[[float], str]] = True,
    animated: bool = False,
) -> None:
    """Draw predicted contact scores together with the true contacts.

    The prediction matrix is shown as an image with the entries below the
    diagonal masked out. The top-L scoring pairs (L = sequence length, pairs
    with |i - j| < 6 excluded) are taken as predicted contacts and overlaid
    as markers:

    * blue: predicted and a true contact,
    * red: predicted but not a true contact,
    * grey: true contact that was not predicted.

    Args:
        predictions: Square matrix of contact scores (tensor or array).
        contacts: Square matrix of true contacts (tensor or array); it is
            combined with boolean masks via ``&`` / ``~``.
        ax: Axes to draw on; defaults to the current pyplot axes.
        cmap: Colormap for the prediction image.
        ms: Marker size for the overlaid points.
        title: A string is used verbatim. Any other truthy value triggers a
            long-range (minsep=24) P@L computation; a callable receives that
            precision and returns the title, otherwise a default title is
            formatted. A falsy value disables the title.
        animated: Forwarded to ``ax.imshow``.
    """
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.detach().cpu().numpy()
    if isinstance(contacts, torch.Tensor):
        contacts = contacts.detach().cpu().numpy()
    if ax is None:
        # Imported locally so the function does not rely on a later notebook
        # cell having imported pyplot into the global namespace.
        import matplotlib.pyplot as plt

        ax = plt.gca()

    seqlen = contacts.shape[0]
    # relative_distance[i, j] = j - i; negative values lie below the diagonal.
    relative_distance = np.add.outer(-np.arange(seqlen), np.arange(seqlen))
    bottom_mask = relative_distance < 0
    # The image uses the unmodified scores; masking of near-diagonal pairs
    # below only affects which pairs count as predicted contacts.
    masked_image = np.ma.masked_where(bottom_mask, predictions)
    invalid_mask = np.abs(np.add.outer(np.arange(seqlen), -np.arange(seqlen))) < 6
    predictions = predictions.copy()
    predictions[invalid_mask] = float("-inf")

    # Score of the L-th best pair: everything at or above it is "predicted".
    topl_val = np.sort(predictions.reshape(-1))[-seqlen]
    pred_contacts = predictions >= topl_val
    true_positives = contacts & pred_contacts & ~bottom_mask
    false_positives = ~contacts & pred_contacts & ~bottom_mask
    other_contacts = contacts & ~pred_contacts & ~bottom_mask

    if isinstance(title, str):
        title_text: Optional[str] = title
    elif title:
        long_range_pl = compute_precisions(predictions, contacts, minsep=24)[
            "P@L"
        ].item()
        if callable(title):
            title_text = title(long_range_pl)
        else:
            title_text = f"Long Range P@L: {100 * long_range_pl:0.1f}"
    else:
        title_text = None

    ax.imshow(masked_image, cmap=cmap, animated=animated)
    ax.plot(*np.where(other_contacts), "o", c="grey", ms=ms)
    ax.plot(*np.where(false_positives), "o", c="r", ms=ms)
    ax.plot(*np.where(true_positives), "o", c="b", ms=ms)
    if title_text is not None:
        ax.set_title(title_text)

    ax.axis("square")
    ax.set_xlim([0, seqlen])
    ax.set_ylim([0, seqlen])

In [14]:
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt
import os
import torch

import esm
from esm.data import ESMStructuralSplitDataset

# Instantiate (and download if needed) the train and valid datasets for every
# split level / cross-validation partition. After the loop,
# `esm_structural_train` and `esm_structural_valid` refer to the last
# combination iterated ('fold', partition '4').
esm_data_root = os.path.expanduser('~/.cache/torch/data/esm')
for split_level in ['family', 'superfamily', 'fold']:
    for cv_partition in ['0', '1', '2', '3', '4']:
        split_kwargs = dict(
            split_level=split_level,
            cv_partition=cv_partition,
            root_path=esm_data_root,
            download=True,
        )
        esm_structural_train = ESMStructuralSplitDataset(split='train', **split_kwargs)
        esm_structural_valid = ESMStructuralSplitDataset(split='valid', **split_kwargs)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files alread

In [15]:
# Load the pretrained ESM-2 model `esm2_t33_650M_UR50D` together with its alphabet.
esm2, esm2_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# Switch to eval mode and keep the model on the CPU for now.
esm2 = esm2.eval().cpu()
# Callable that turns a list of (label, sequence) pairs into (labels, strings, tokens).
esm2_batch_converter = esm2_alphabet.get_batch_converter()

In [16]:
# Wrap the model with `quant_model_acts` (not defined in this notebook; presumably
# provided by the `from linear_quant import *` star import).
# NOTE(review): the meaning of the positional arguments `0, True` is not visible
# here. The following cells run forward passes and then call
# `save_model_act_stats`, so this looks like a statistics-collection setup — confirm.
quant_esm2 = quant_model_acts(esm2, 0, True, exclude_part=["base_model"])
# Move the wrapped model to the second GPU.
quant_esm2 = quant_esm2.cuda("cuda:1")

In [17]:
# Number of examples in the training split left bound by the dataset loop above.
len(esm_structural_train)

11207

In [18]:
# Number of examples in the validation split left bound by the dataset loop above.
len(esm_structural_valid)

4090

In [19]:
# Run the first 50 training sequences through the wrapped model, one at a
# time. The outputs are not used; the forward passes are presumably what
# populates the activation statistics saved below — confirm in linear_quant.
for i in range(50):
    ele = esm_structural_train[i]  # look the example up once per iteration
    data = [(i, ele['seq'])]
    batch_labels, batch_strs, batch_tokens = esm2_batch_converter(data)
    with torch.no_grad():
        results = quant_esm2(batch_tokens.cuda("cuda:1"), repr_layers=[33], return_contacts=True)

# Persist the statistics so they can be loaded into the quantized model later.
os.makedirs('./output/stats/', exist_ok=True)
act_stats_save_path = './output/stats/act_stats_full_pre_50_v1.pth'
act_dict = save_model_act_stats(quant_esm2, act_stats_save_path)

In [20]:
# Reload a fresh copy of the pretrained model and build the quantized variant from it.
# NOTE(review): `esm2` is passed to `quant_model_acts` here and is later evaluated
# as the unquantized baseline. If `quant_model_acts` modifies its argument in
# place, the baseline numbers would come from the quantized model — confirm.
esm2, esm2_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm2 = esm2.eval().cpu()
# new_model = quant_model_acts(new_model, 8, False, exclude_part=["base_model"], cali_batch_size=50)
# NOTE(review): `8, False` presumably select the bit width and disable the
# statistics-collection mode used earlier; `cali_batch_size=50` matches the 50
# calibration sequences — verify against linear_quant.
quant_esm2 = quant_model_acts(esm2, 8, False, exclude_part=["base_model"], cali_batch_size=50, quant_scheme="pwlq-3")

In [21]:
# Load the activation statistics stored at `act_stats_save_path` into the
# quantized model, using the "top_15" clipping method.
# NOTE(review): the return value is bound to `mode` and not used in the visible
# cells; the semantics of `act_clip_method` are defined in linear_quant — confirm.
mode = load_model_act_stats(quant_esm2, act_stats_save_path, act_clip_method="top_15")
# mode = load_model_act_stats(new_model, act_stats_save_path, act_clip_method="clip_0.999")

In [22]:
# Evaluate the baseline model on every validation example, one sequence at a time.
esm2 = esm2.cuda("cuda:1")
esm2_results = []

# Use the dataset's own length instead of a hard-coded example count.
for i in range(len(esm_structural_valid)):
    if i % 200 == 0:
        print(i)
    ele = esm_structural_valid[i]  # look the example up once per iteration
    data = [(i, ele['seq'])]
    batch_labels, batch_strs, batch_tokens = esm2_batch_converter(data)
    with torch.no_grad():
        results = esm2(batch_tokens.cuda("cuda:1"), repr_layers=[33], return_contacts=True)

    # Ground-truth contacts: pairs whose distance is below 8.
    metrics = {"id": i, "model": "ESM-2 (Unsupervised)"}
    metrics.update(evaluate_prediction(results["contacts"], ele['dist'] < 8))
    esm2_results.append(metrics)

# One row per validation example, one column per metric.
esm2_results = pd.DataFrame(esm2_results)

# 4090
# 0.6377077553375717
# 0.501526620944001

0
200
400
600
800
1000
1200
1400
1600
1800
2000
2200
2400
2600
2800
3000
3200
3400
3600
3800
4000


In [23]:
# Mean local-range precision of the baseline model at L/5, L/2 and L.
for metric_name in ("local_P@L5", "local_P@L2", "local_P@L"):
    print(esm2_results[metric_name].mean())

0.7631056543612684
0.6361556842394883
0.501526620944001


In [None]:
# Evaluate the quantized model on every validation example, one sequence at a time.
quant_esm2_results = []
quant_esm2 = quant_esm2.cuda("cuda:1")

# Use the dataset's own length instead of a hard-coded example count.
for i in range(len(esm_structural_valid)):
    if i % 200 == 0:
        print(i)
    ele = esm_structural_valid[i]  # look the example up once per iteration
    data = [(i, ele['seq'])]
    batch_labels, batch_strs, batch_tokens = esm2_batch_converter(data)
    with torch.no_grad():
        results = quant_esm2(batch_tokens.cuda("cuda:1"), repr_layers=[33], return_contacts=True)

    # Label these rows distinctly: the original reused the baseline's label,
    # making the two result tables indistinguishable by their "model" column.
    metrics = {"id": i, "model": "ESM-2 Quantized (Unsupervised)"}
    metrics.update(evaluate_prediction(results["contacts"], ele['dist'] < 8))
    quant_esm2_results.append(metrics)

# One row per validation example, one column per metric.
quant_esm2_results = pd.DataFrame(quant_esm2_results)

0
200
400
600
800
1000
1200
1400
1600
1800
2000
2200
2400
2600
2800
3000
3200


In [None]:
# Mean local-range precision of the quantized model at L/5, L/2 and L.
for metric_name in ("local_P@L5", "local_P@L2", "local_P@L"):
    print(quant_esm2_results[metric_name].mean())