In [1]:
import os
# Process-wide environment configuration. This runs in the first cell so the
# values are in place before JAX is imported by a later cell.
os.environ.update({
  "CUDA_VISIBLE_DEVICES": "3",
  "TF_USE_NVLINK_FOR_PARALLEL_COMPILATION": "0",
  'TF_CPP_MIN_LOG_LEVEL': '3',
})

# Memory-related settings, kept in a named dict so they can be inspected later.
# NOTE(review): a memory fraction of 4.0 together with unified memory presumably
# lets XLA oversubscribe GPU memory into host RAM - confirm against the JAX docs.
ENV = {"TF_FORCE_UNIFIED_MEMORY":"1", "XLA_PYTHON_CLIENT_MEM_FRACTION":"4.0"}
os.environ.update(ENV)

In [2]:
import numpy as np
import pickle
# Which dataset variant to load; only "2k" is handled in this cell.
DATASET = "2k"
if DATASET == "2k":
  # SECURITY NOTE: pickle.load executes arbitrary code embedded in the file.
  # Only load these pickles if they come from a trusted source.
  with open("new_labels_2k.pkl","rb") as handle:
    LABELS = pickle.load(handle)
  with open("assignments_2k.pkl","rb") as handle:
    CROSS_ASSIGNED = pickle.load(handle)

  # TMS[a][b] holds a symmetric pairwise score: the larger of the two values
  # listed for the pair (a, b) in the text file.
  TMS = {}
  # Use a context manager so the score file is closed deterministically
  # (the handle was previously left open for the garbage collector).
  with open("tmscores_2k.txt","r") as handle:
    for line in handle:
      fields = line.rstrip().split()
      if not fields:
        # tolerate blank lines (e.g. a trailing newline at end of file)
        continue
      a,b,tm_a,tm_b = fields
      tmscore = max(float(tm_a),float(tm_b))
      TMS.setdefault(a, {})[b] = tmscore
      TMS.setdefault(b, {})[a] = tmscore

In [3]:
import random
def get_data(M="split",
             feats="pair_A",
             pad_len=500,
             mask_alt=False,
             seed=0):
  """Load per-entry embeddings and labels, padded to a common length.

  Reads the module-level CROSS_ASSIGNED, LABELS and TMS objects and the
  ``embeddings/<M>/<entry>.npz`` files on disk.

  Args:
    M: sub-directory of ``embeddings/`` to read from. The value "esm2" is
      handled specially: features are taken as ``rep[idx]``.
    feats: which array to use. For "esm2", "last" selects index -1 along the
      second axis of ``rep[idx]``; otherwise ``feats`` is the key inside the
      ``.npz`` file ("pair_B" is additionally axis-swapped on axes 0 and 1).
    pad_len: length every entry is zero-padded to along axis 0. Entries whose
      label length exceeds ``pad_len`` are skipped (their name and length are
      printed).
    mask_alt: if True, also mask out positions flagged in ``bind_alt`` that
      are not binding according to ``dist_all``.
    seed: key into CROSS_ASSIGNED selecting the entry -> group assignment.

  Returns:
    ``(data, labels, sample_weight)``; each is a list of three NumPy arrays,
    indexed by the group value stored in ``CROSS_ASSIGNED[seed]``
    (assumes those values are 0, 1 or 2 - TODO confirm). Label channels are
    ``[..., 0]`` binding flag (``dist_all < 5.0``), ``[..., 1]`` mask and
    ``[..., 2]`` distance-bin index. ``sample_weight`` is ``1 / counts`` where
    ``counts`` is one plus the number of other entries with ``TMS > 0.5``.
  """
  data = [[],[],[]]
  labels = [[],[],[]]
  sample_weight = [[],[],[]]

  assignments = CROSS_ASSIGNED[seed]
  dist_bins = np.append(np.linspace(3.2,6.0,15),np.inf)
  for a,v in assignments.items():

    # get labels
    y = LABELS[a]
    L = y["dist_all"].shape[0]
    if L > pad_len:
      # too long to fit in the padded array: report and skip
      print(a,L)
      continue

    # number of bin edges strictly below each distance
    y_dist = np.sum(dist_bins < y["dist_all"][...,None],-1)

    # all ligands
    y_bind_all = y["dist_all"] < 5.0

    # confident ligands
    y_bind_sub = y["dist_sub"] < 5.0

    # adjust mask: drop positions that bind only non-confident ligands
    y_mask = y["mask"].copy()
    y_mask[np.logical_and(y_bind_all == True, y_bind_sub == False)] = False

    if mask_alt:
      # homologous ligands (with TMscore > 0.8)
      y_bind_alt = y["bind_alt"]
      # mask alternative binding positions
      y_mask[np.logical_and(y_bind_all == False, y_bind_alt == True)] = False

    y = np.stack([y_bind_all, y_mask, y_dist],-1)

    # gather data; the context manager closes the .npz archive once the
    # requested arrays have been read into memory
    if M == "esm2":
      with np.load(f"embeddings/esm2/{a}.npz") as npz:
        rep, idx = npz["rep"], npz["idx"]
      x = rep[idx,-1] if feats == "last" else rep[idx]
    else:
      with np.load(f"embeddings/{M}/{a}.npz") as npz:
        if feats == "pair_B":
          x = npz["pair_B"].swapaxes(0,1)
        else:
          x = npz[feats]

    # reshape: flatten everything after the first axis
    x = x.reshape(x.shape[0],-1)

    # pad
    l = x.shape[0]
    l_ = y.shape[0]
    if l != l_: print(a)  # embedding/label length mismatch: report entry name
    data[v].append(np.pad(x,[[0,pad_len-l],[0,0]]))
    # pad labels by their own length so the stacked array stays rectangular
    # even when the lengths disagree
    labels[v].append(np.pad(y,[[0,pad_len-l_],[0,0]]))

    # weight: down-weight entries that have many similar (TMS > 0.5) partners
    counts = 1 + sum(1 for b in assignments if a != b and TMS[a][b] > 0.5)
    sample_weight[v].append(1/counts)

  # combine
  data = [np.array(v) for v in data]
  labels = [np.array(v) for v in labels]
  sample_weight = [np.array(v) for v in sample_weight]

  return data, labels, sample_weight

In [5]:
import jax
import jax.numpy as jnp
import haiku as hk
import optax
import numpy as np
from tqdm import tqdm
from jax import tree_util

def l2_regularization(params, lmbda=1.0):
    """Return ``lmbda`` times the sum of squared entries over every leaf of ``params``.

    ``params`` may be any pytree (e.g. a nested dict of arrays); each leaf
    contributes the sum of its squared elements.
    """
    leaf_penalties = [
        jnp.sum(jnp.square(leaf)) for leaf in tree_util.tree_leaves(params)
    ]
    return lmbda * sum(leaf_penalties)


class mk_model:
  """Per-position logistic regression (single linear layer + sigmoid) in Haiku.

  ``data``, ``labels`` and ``sample_weight`` are lists indexed by ``mode``;
  index 0 is the set used for fitting and for the normalisation statistics.
  Label channel ``[..., 0]`` is the binary target and ``[..., 1]`` the mask
  of positions that count towards the loss and the recovery metric.
  """

  def __init__(self, data, labels, sample_weight, normalize=True, lam=0.01):
    """Store the data, compute normalisation statistics and initialise parameters.

    Args:
      data: list of feature arrays; ``data[0]`` is used for fitting.
      labels: list of label arrays aligned with ``data``.
      sample_weight: list of per-sample weight arrays aligned with ``data``.
      normalize: if True, standardise features with the weighted, masked
        mean/std of ``data[0]``.
      lam: coefficient of the L2 penalty on the ``"w"`` parameters.
    """
    self.data = data
    self.labels = labels
    self.sample_weight = sample_weight
    self.normalize = normalize
    self.history = []
    self.lam = lam

    # Normalize data: weighted mean/std over unmasked positions of data[0]
    if normalize:
      mask = sample_weight[0][:,None,None] * labels[0][...,1][:,:,None]
      self._mean = (data[0] * mask).sum((0,1)) / mask.sum((0,1))
      std = np.sqrt((np.square(data[0] - self._mean) * mask).sum((0,1)) / mask.sum((0,1)))
      # A feature with zero variance would cause a division by zero (NaN/inf)
      # in _build_model; leave such features centred but unscaled instead.
      self._std = np.where(std == 0, 1.0, std)

    self.F = data[0].shape[-1]

    # device copies of the data (used by fit)
    self._data = [jnp.array(x) for x in self.data]
    self._labels = [jnp.array(x) for x in self.labels]
    self._sample_weight = [jnp.array(x) for x in self.sample_weight]

    # Define the model function and initialize it
    self.model = hk.without_apply_rng(hk.transform(self._build_model))
    rng = jax.random.PRNGKey(42)
    # Parameter shapes depend only on the feature axis, so a single sample is
    # enough for initialisation; this avoids a forward pass over all of data[0].
    self.params = self.model.init(rng, data[0][:1])

    # Define optimizer
    self.recompile(1e-3)

  def _build_model(self, x):
    """Haiku forward function: optional standardisation, then Linear(1) + sigmoid."""
    if self.normalize:
      # mean/std are stored as parameters (so they are saved with the model)
      # but excluded from gradient flow via stop_gradient
      mean = hk.get_parameter("mean", shape=(self.F,), init=hk.initializers.Constant(self._mean))
      std = hk.get_parameter("std", shape=(self.F,), init=hk.initializers.Constant(self._std))
      x = (x - jax.lax.stop_gradient(mean)) / jax.lax.stop_gradient(std)

    # final layer (weights start at zero)
    x = jax.nn.sigmoid(hk.Linear(1, w_init=hk.initializers.Constant(0))(x))
    return x

  def loss_fn(self, params, inputs, targets, sample_weights):
    """Masked, sample-weighted binary cross-entropy plus an L2 penalty.

    The per-position loss is averaged over masked positions of each sample,
    then averaged over samples using ``sample_weights``; ``self.lam`` times
    the squared norm of every ``"w"`` parameter is added.
    """
    predictions = self.model.apply(params, inputs)
    y_true_bce = targets[..., 0]
    y_mask = targets[..., 1]
    # 1e-7 keeps the logarithms finite when predictions saturate
    loss = -y_true_bce * jnp.log(predictions[..., 0] + 1e-7) - (1 - y_true_bce) * jnp.log(1 - predictions[..., 0] + 1e-7)
    masked_loss = jnp.sum(loss * y_mask, axis=-1) / (jnp.sum(y_mask, axis=-1) + 1e-7)
    weighted_loss = (masked_loss * sample_weights).sum() / (sample_weights.sum() + 1e-7)

    # l2: only weight matrices ("w") are penalised
    l2_loss = []
    for _,p in params.items():
      if "w" in p:
        l2_loss.append(jnp.square(p["w"]).sum())

    # Combine all losses
    total_loss = weighted_loss + self.lam * sum(l2_loss)
    return total_loss

  def fit(self, epochs=1, batch_size=64, verbose=True):
    """Run ``epochs`` passes of mini-batch updates over the mode-0 data.

    Batches are drawn from a fresh permutation each epoch using NumPy's
    global random state.
    """
    num_samples = len(self._data[0])
    indices = np.arange(num_samples)
    for epoch in range(epochs):
      # Shuffle indices each epoch
      np.random.shuffle(indices)
      running_loss = 0
      steps_per_epoch = num_samples // batch_size + int(num_samples % batch_size != 0)
      pbar = tqdm(total=steps_per_epoch, disable=not verbose, dynamic_ncols=True)
      for i in range(0, num_samples, batch_size):
        batch_indices = indices[i:i + batch_size]
        batch_data = self._data[0][batch_indices]
        batch_labels = self._labels[0][batch_indices]
        batch_sample_weights = self._sample_weight[0][batch_indices]
        self.params, self.opt_state, loss = self.update(self.params, self.opt_state, batch_data, batch_labels, batch_sample_weights)
        running_loss += loss
        avg_loss = running_loss / (i // batch_size + 1)
        pbar.set_description(f"loss: {avg_loss:.4f}")
        pbar.update(1)
      pbar.close()

  def predict(self, mode=0, batch_size=12, verbose=True):
    """Return model outputs for ``self.data[mode]`` as one NumPy array, computed in batches."""
    data = self.data[mode]
    num_samples = len(data)
    predictions = []
    pbar = tqdm(total=num_samples, disable=not verbose, desc="Predicting", dynamic_ncols=True)
    for i in range(0, num_samples, batch_size):
      batch_data = data[i:i+batch_size]
      predictions.append(self.model.apply(self.params, batch_data))
      pbar.update(len(batch_data))
    pbar.close()
    return np.concatenate(predictions, axis=0)

  def recompile(self, learning_rate=None):
    """Create a fresh Adam optimizer and jitted update step.

    Does nothing when ``learning_rate`` is None. Otherwise the optimizer
    state is re-initialised from the current parameters.
    """
    if learning_rate is not None:
      self.optimizer = optax.adam(learning_rate)
      self.opt_state = self.optimizer.init(self.params)
      def _update(params, opt_state, inputs, targets, sample_weights):
        loss_val, grads = jax.value_and_grad(self.loss_fn)(params, inputs, targets, sample_weights)
        updates, opt_state = self.optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state, loss_val
      self.update = jax.jit(_update)

  def recovery(self, mode=0, mean=True, batch_size=12, k=None,):
    """Top-k recovery on the masked positions of each sample.

    For every sample the ``k`` highest-scoring masked positions are selected
    (``k`` defaults to the number of positive masked positions) and the mean
    prediction and mean true label over that selection are recorded.

    Args:
      mode: which data split to evaluate.
      mean: if True, return sample-weighted averages; otherwise per-sample lists.
      batch_size: batch size forwarded to ``predict``.
      k: fixed selection size, or None to use the per-sample positive count.

    Returns:
      ``[p_rec, t_rec]`` - mean predicted score and mean true label of the
      selection. Samples with an empty selection (e.g. no masked positives)
      are NaN in the per-sample lists and are left out of the weighted
      averages so that one such sample does not turn the aggregate into NaN.
    """
    # Predict
    predictions = self.predict(mode, batch_size=batch_size, verbose=False)
    predictions = predictions[...,0]

    true_labels = self.labels[mode][...,0]
    mask = self.labels[mode][...,1]
    weights = self.sample_weight[mode]

    t_rec = []
    p_rec = []
    valid = []  # True where the sample contributed a non-empty selection

    for n in range(len(true_labels)):
      mask_ = mask[n] == True
      pred_ = predictions[n][mask_]
      true_ = true_labels[n][mask_]

      top_k = int(true_.sum()) if k is None else k
      sorted_indices = pred_.argsort()[::-1][:top_k]

      if len(sorted_indices) == 0:
        # nothing to score: record NaN explicitly instead of mean([])
        t_rec.append(np.nan)
        p_rec.append(np.nan)
        valid.append(False)
        continue

      t_rec.append(true_[sorted_indices].mean())
      p_rec.append(pred_[sorted_indices].mean())
      valid.append(True)

    if mean:
      valid = np.asarray(valid, dtype=bool)
      w = np.asarray(weights)[valid]
      denom = w.sum()
      if denom > 0:
        t_rec = (np.asarray(t_rec, dtype=float)[valid] * w).sum() / denom
        p_rec = (np.asarray(p_rec, dtype=float)[valid] * w).sum() / denom
      else:
        t_rec = np.nan
        p_rec = np.nan

    return [p_rec,t_rec]

  def evaluate(self):
    """Append the recovery metrics for modes 0, 1 and 2 to ``self.history``."""
    self.history.append([self.recovery(0), self.recovery(1), self.recovery(2)])

In [6]:
def clear_mem():
  """Delete every live JAX device array to free accelerator memory.

  Any Python reference to a deleted array becomes unusable afterwards, so
  call this only between independent runs.
  """
  if hasattr(jax, "live_arrays"):
    # public API; jax.lib.xla_bridge is deprecated/removed in newer JAX releases
    live = jax.live_arrays()
  else:
    # legacy path for older JAX versions that lack jax.live_arrays
    live = jax.lib.xla_bridge.get_backend().live_buffers()
  for buf in live: buf.delete()

In [7]:
import gc
import numpy as np
import random

In [None]:
# ---- run configuration -------------------------------------------------------
MODE = f"attempt_7_{DATASET}"
LAM = 0.03
lam_str = str(LAM).replace(".","-")
NAME = f"{MODE}_lam{lam_str}"
os.makedirs(f"fix/{NAME}",exist_ok=True)
NORMALIZE_BY_TRAIN = True
MASK_ALT = False
REDO = False

# name -> recovery history array, either freshly trained or loaded from cache
H10 = {}

# Each X is a list of [M, feats] pairs whose features are concatenated.
for X in [
  [["split_nosc","pair_A"],["split_nosc","pair_B"]],
]:
  for seed in range(10):
    clear_mem()
    name = "_".join(sum(X,[]))+f"_{seed}"
    print(name)
    history_file = f"fix/{NAME}/{name}.history.npy"

    if not REDO and os.path.isfile(history_file):
      # cached result from an earlier run
      h = np.load(history_file)
    else:
      # labels and weights come from the first feature set; further sets
      # only contribute extra feature columns
      data, labels, sample_weight = get_data(*X[0], seed=seed, mask_alt=MASK_ALT)
      print(X[0])
      for x in X[1:]:
        print(x)
        data_,_,_ = get_data(*x,
                             seed=seed,
                             mask_alt=MASK_ALT)
        data = [np.concatenate([v,data_[k]],-1) for k,v in enumerate(data)]

      model = mk_model(
              labels = labels,
              data = data,
              sample_weight=sample_weight,
              normalize=NORMALIZE_BY_TRAIN,
              lam=LAM,
            )

      # two stages with decreasing learning rate; evaluate after every 4 epochs
      for learning_rate in (1e-3, 1e-4):
        model.recompile(learning_rate=learning_rate)
        for _ in range(40):
          model.fit(epochs=4,  batch_size=12, verbose=False)
          model.evaluate()
          print(model.history[-1])

      h = np.array(model.history)
      filename = f"fix/{NAME}/"+name
      with open(filename+".pickle","wb") as handle:
        pickle.dump(model.params, handle)
      np.save(filename+".history.npy", h)
      del data
      gc.collect()

    H10[name] = h