# Environment

In [1]:
import numpy as np
import pandas as pd
import cupy as cp

In [2]:
!git clone https://github.com/sinc-lab/exp2GO.git

Cloning into 'exp2GO'...
remote: Enumerating objects: 126, done.[K
remote: Counting objects: 100% (126/126), done.[K
remote: Compressing objects: 100% (113/113), done.[K
remote: Total 126 (delta 31), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (126/126), 38.82 MiB | 17.37 MiB/s, done.
Resolving deltas: 100% (31/31), done.


In [3]:
path = 'exp2GO/data/'

# Data

In [4]:
species = 'ara' # ara dicty yeast

if species=='ara':
  gafnr_1 = '131'
  gafnr00 = '138'
else:
  gafnr_1 = '59'
  gafnr00 = '66'

## Annotations

In [5]:
onto = 'BP'

In [6]:
ancestorsT_1 = pd.read_csv(path + species + '_terms_gaf'+ gafnr_1 +
                           '_obo2016-06-01_' + onto + '_with_expr_EXP_FILTERED_terms_anc.csv',
                           header=None, index_col=0)

#ancestorsT_1 = ancestorsT_1.dropna(axis=0, how='all') # DROP NA VALUES
ancestorsT_1.fillna('',inplace=True)
ancestorsT_1.index = [x.upper() for x in ancestorsT_1.index]

labelsT_1 = np.unique(ancestorsT_1.to_numpy().ravel())
labelsT_1 = labelsT_1[1:]

In [7]:
ancestorsT0 = pd.read_csv(path + species + '_terms_gaf' + gafnr00 +
                          '_obo2016-06-01_' + onto + '_with_expr_EXP_FILTERED_terms_anc.csv',
                          header=None, index_col=0)

ancestorsT0 = ancestorsT0.dropna(axis=0, how='all') # drop NA values 
ancestorsT0.fillna('',inplace=True)
ancestorsT0.index = [x.upper() for x in ancestorsT0.index]

labelsT0 = np.unique(ancestorsT0.to_numpy().ravel())
labelsT0 = labelsT0[1:]

## Distance matrices

In [8]:
De = pd.read_csv(path + species + '_expression_dist_cosine.csv.zip', header=0, index_col=0)

In [9]:
Dg = pd.read_csv(path + species + '_semantic_dist_' + onto + '_rel_min_exp_FILTERED_terms.csv.zip',
                 header=0, index_col=0)
Dg.index = Dg.index.str.upper()
Dg.columns = Dg.columns.str.upper()

### Filter with T0 genes

In [10]:
genes = list(ancestorsT0.index)
#print(len(genes))

In [11]:
#Filter expression matrix according to TO-GO matrix
De = De.filter(items=genes, axis=0)
De = De.filter(items=genes, axis=1)

In [12]:
Dg = Dg.filter(items=genes, axis=0)
Dg = Dg.filter(items=genes, axis=1)

In [13]:
# Filter T-1 genes
ancestorsT_1 = ancestorsT_1.filter(items=genes, axis=0) 

## Detect LKs and NKs

In [14]:
def countLKNCNK(prots_Tini, prots_Tend):
  
  allprots = np.union1d(prots_Tini.index,prots_Tend.index)
  LK = []
  NK = []

  Nlk, Nnc, Nnk = 0, 0, 0
  for prot in allprots:
    terms_1 = set(prots_Tini.loc[prot]) - set([''])
    terms00 = set(prots_Tend.loc[prot]) - set([''])
    if terms_1 == terms00:
      Nnc += 1
    elif len(terms_1) == 0:
      Nnk += 1
      NK.append(prot)
    elif terms_1 < terms00:
      Nlk += 1
      LK.append(prot)
    else:
      # (implicit NC)
      if len(terms_1) == len(terms00):
        print('Warning: protein {} has same number of terms in T-1 than T0, BUT changed!'.format(prot))
      else:
        print('Warning: protein {} has more terms in T-1 ({}) than T0 ({})'.format(prot,len(terms_1),len(terms00)))

  return Nlk, Nnc, Nnk, LK, NK

In [15]:
Nlk, Nnc, Nnk, LK, NK = countLKNCNK(ancestorsT_1, ancestorsT0)

## Convert GO-terms string to numbers

In [16]:
alllabels = list(set(labelsT_1).intersection(set(labelsT0)))
labels = np.zeros((len(alllabels),) , dtype=np.uint64)
for i, lab in enumerate(alllabels):
  labels[i]=int(lab[3:])

In [17]:
# change GO terms to numbers only in ancestors
def ancestors2num(anc_str):
  ancestors = pd.DataFrame()
  for gen in anc_str:
    labs = []
    for i in range(len(anc_str[gen])):
      num=anc_str[gen][i][3:]
      if num != '':
        labs.append(int(num))
      else:
        labs.append(0)
    ancestors[gen] = labs
  ancestors.set_index(anc_str.index, inplace=True)
  return ancestors

In [18]:
ancT_1 = ancestors2num(ancestorsT_1)
ancT0 = ancestors2num(ancestorsT0)

# NNMF d_GO reconstruction

#### Auxiliary functions

In [19]:
def send2back(d,idx):
  # sent to back these cols and rows in d

  Nd = d.shape[0]
  Ns = idx.shape[0]

  idxinv = np.setdiff1d(np.arange(0,Nd),idx)
  idxinvidx = np.concatenate((idxinv, idx)) # new order
  
  dx = np.zeros(d.shape)

  dx[np.ix_(range(Nd-Ns),range(Nd-Ns))] = d[np.ix_(idxinv,idxinv)] # compacted left-up submatrix
  dx[np.ix_(range(Nd-Ns,Nd),range(Nd))] = d[np.ix_(idx,idxinvidx)] # last rows
  dx[np.ix_(range(Nd),range(Nd-Ns,Nd))] = d[np.ix_(idxinvidx,idx)] # last cols

  return dx

def zeroback(d, Ns):

  Nd = d.shape[0]
  d[np.ix_(range(Nd-Ns,Nd),range(Nd))] = 0;
  d[np.ix_(range(Nd),range(Nd-Ns,Nd))] = 0;

  return d

def return2future(db, idx):
  # restore last of db ows/cols to the idx positions

  Nd = db.shape[0]
  Ns = idx.shape[0]

  idxinv = np.setdiff1d(np.arange(0,Nd),idx)
  idxinvidx = np.concatenate((idxinv, idx)) # new order
  
  dx = np.zeros(db.shape)

  dx[np.ix_(idxinv,idxinv)] = db[np.ix_(range(Nd-Ns),range(Nd-Ns))] # expand rows and cols
  dx[np.ix_(idx,idxinvidx)] = db[np.ix_(range(Nd-Ns,Nd),range(Nd))] # return last rows
  dx[np.ix_(idxinvidx,idx)] = db[np.ix_(range(Nd),range(Nd-Ns,Nd))] # return last cols

  return dx

#### NNMF CuPy version

In [20]:
def ssnmf_cp(De,Dg,Nc,lmbd,k,max_iter):
  
  (m,n) = De.shape

  cp.random.seed(seed=333) 
  A = cp.random.rand(m,k)
  S1 = cp.random.rand(k,n)
  S2 = cp.random.rand(k,n)

  p = m-Nc
  W = cp.zeros((m,n))
  W[cp.ix_(range(Nc),range(Nc))] = 1

  # learinng the base
  e1ant = 1000
  e2ant = 1000
  cont = 0
  maxcont = 10
  tol = 0.001

  l = 0
  cont = 0
  while (l<max_iter) and (cont<maxcont):
    A = A *( De @S1.T+  lmbd*(W *Dg)@S2.T)/((A@S1)@S1.T+  lmbd*(W *(A@S2))@S2.T+0.0000001)
    for u in range(k):
      A[:,u] = A[:,u]/cp.linalg.norm(A[:,u])
    S1 = S1 * ((A.T@De)/(A.T@(A@S1)+0.0000001))
    S2 = S2 * ((A.T@(W*Dg))/(A.T@(W*(A@S2))+0.0000001))
    err1 = cp.linalg.norm(De-A@S1,'fro')
    err2 = cp.linalg.norm(W*(Dg-A@S2),'fro')
    if (2-err1/e1ant-err2/e2ant)<tol:
      cont = cont+1
    else:
      cont = 0
    e1ant = err1
    e2ant = err2
    l = l+1
    #if l%100 == 0: print('l1: ',l)
    #if l==1: time0 = time.time()
    #if l==11: print('l1 10 time', time.time()-time0)

  miter = l        

  # full Dg estimation
  Dr1 = A@S2
  Dr1[cp.ix_(range(Nc),range(Nc,len(Dr1)))] =0
  Dr2 = ((1-W)*Dr1).T + Dr1

  # weigths redefinition
  W = cp.ones((m,n))
  W[cp.ix_(range(Nc,m),range(Nc,n))] = 0

  S2[cp.ix_(range(k),range(Nc,n))] = cp.ones((k,p)); #rand(k,p)
  eant = 1000
  cont = 0
  maxcont = 10
  tol = 0.00001

  l = 0
  cont = 0
  while (l<max_iter) and (cont<maxcont):
    S2 = S2*((A.T@(W*Dr2))/(A.T@(W*(A@S2))+0.0000001))
    err=cp.linalg.norm(W*(Dr2-A@S2),'fro')
    if abs(1-err/eant)<tol:
      cont = cont+1
    else:
      cont = 0
    eant=err
    l = l+1
    #if l%100 == 0: print('l2: ',l)

  aux = A@S2
  aux1 = aux[Nc:,Nc:]
  aux1 = (aux1.T+aux1)/2

  #Dr3 = Dr2
  Dr2[cp.ix_(range(Nc,len(Dr2)),range(Nc,len(Dr2)))] = aux1
  Dr2 = cp.minimum(Dr2,cp.max(Dg)*cp.ones(Dr2.shape))

  return Dr2

#### NNMF numpy version

In [21]:
def ssnmf(De,Dg,Nc,lmbd,k,max_iter):

  (m,n) = De.shape
  np.random.seed(seed=333)
  A = np.random.rand(m,k)
  S1 = np.random.rand(k,n)
  S2 = np.random.rand(k,n)

  p = m-Nc
  W = np.zeros((m,n))
  W[np.ix_(range(Nc),range(Nc))] = 1

  # learning the base
  e1ant = 1000
  e2ant = 1000
  cont = 0
  maxcont = 10
  tol = 0.0001

  l = 0
  cont = 0
  while (l<max_iter) and (cont<maxcont):
    A = A *( De @S1.T+  lmbd*(W *Dg)@S2.T)/((A@S1)@S1.T+  lmbd*(W *(A@S2))@S2.T+0.0000001)
    for u in range(k):
      A[:,u] = A[:,u]/np.linalg.norm(A[:,u])
    S1 = S1 * ((A.T@De)/(A.T@(A@S1)+0.0000001))
    S2 = S2 * ((A.T@(W*Dg))/(A.T@(W*(A@S2))+0.0000001))
    err1 = np.linalg.norm(De-A@S1,'fro')
    err2 = np.linalg.norm(W*(Dg-A@S2),'fro')
    if (2-err1/e1ant-err2/e2ant)<tol:
      cont = cont+1
    else:
      cont = 0
    e1ant = err1
    e2ant = err2
    l = l+1
    if l%100 == 0: print('l1: ',l)
    if l==1: time0 = time.time()
    if l==11: print('l1 10 time', time.time()-time0)

  miter = l        
 
  # full Dg estimation
  Dr1 = A@S2
 
  Dr1[cp.ix_(range(Nc),range(Nc,len(Dr1)))] =0
  Dr2 = ((1-W)*Dr1).T + Dr1

  # weigths redefinition
  W = np.ones((m,n))
  W[np.ix_(range(Nc,m),range(Nc,n))] = 0

  S2[np.ix_(range(k),range(Nc,n))] = np.ones((k,p)); #rand(k,p)
  eant = 1000
  cont = 0
  maxcont = 10
  tol = 0.00001

  l = 0
  cont = 0
  while (l<max_iter) and (cont<maxcont):
    S2 = S2*((A.T@(W*Dr2))/(A.T@(W*(A@S2))+0.0000001))
    err=np.linalg.norm(W*(Dr2-A@S2),'fro')
    if abs(1-err/eant)<tol:
      cont = cont+1
    else:
      cont = 0
    eant=err
    l = l+1
    if l%100 == 0: print('l2: ',l)

  aux = A@S2
  aux1 = aux[Nc:,Nc:]
  aux1 = (aux1.T+aux1)/2

  #Dr3 = Dr2
  Dr2[np.ix_(range(Nc,len(Dr2)),range(Nc,len(Dr2)))] = aux1
  Dr2 = np.minimum(Dr2,np.max(Dg)*np.ones(Dr2.shape))

  return Dr2

#### NNMF-GO

In [22]:
def nnmfgo(d_expr, d_GO, idx_to_rec, lmbd, dim, max_iter):

  # Format GO distance matrix (back and blank genes to annotate)
  dg = send2back(d_GO,idx_to_rec)
  dg = zeroback(dg,len(idx_to_rec))

  # Format expression distance matrix (back genes to annotate)
  de = send2back(d_expr,idx_to_rec)

  # NNMF
  num_genes_annotated = d_expr.shape[0]-len(idx_to_rec)

  # without cupy: 392 s for 10 iterations (aprox 10 h for 400 l1 + 500 l2)
  #rec = ssnmf(d_expr,d_GO,num_genes_annotated,lmbd,dim,max_iter)
  
  # with cupy: 11 s for 10 iterations (aprox 18 min for 400 l1 + 500 l2)
  d_GO_cp = cp.array(dg)  
  d_expr_cp = cp.array(de)
  rec_cp = ssnmf_cp(d_expr_cp,d_GO_cp,num_genes_annotated,lmbd,dim,max_iter)
  rec = cp.asnumpy(rec_cp)
  #rec = cp.asnumpy(d_GO_cp)

  # Restore the original order in the distance matrix
  d_GO_rec = return2future(rec,idx_to_rec)

  return d_GO_rec

# Bayesian inference

In [23]:
def bayesprobs(d,idxB,labelsAwB,ancestors,pexp):
  # ancestors: GO terms for each known gene

  idxA = np.setdiff1d(np.arange(0,len(d)),idxB) # annotated genes
  
  weights = np.zeros((len(idxB),len(labelsAwB))) # % ~> likelihood p(gene|label)
  countAB = np.zeros((len(idxB),len(labelsAwB))) # % ~> prior p(label)

  for genB in range(len(idxB)):
    col = idxB[genB]
    dnorm_col = d[:,col] 
    dnorm_col = (dnorm_col-np.min(dnorm_col))/(np.max(dnorm_col)-np.min(dnorm_col))
    for genA in range(len(idxA)):
      labsA = ancestors.iloc[idxA[genA],:].dropna().to_numpy()
      labsA = np.setxor1d([0],labsA)
      row = idxA[genA]
      d_gen=d[row,col]
      for t in range(len(labsA)):
        idxLab = np.where(labelsAwB == labsA[t])
        weights[genB,idxLab] = weights[genB,idxLab]+(2/(1+d_gen)-1)**pexp 
        countAB[genB,idxLab] = countAB[genB,idxLab]+1

  return (weights, countAB)

In [24]:
def bayesinf(genB,p,cumpcut,labelsAwB):

  prob = p[genB,:]
  prob = prob/np.sum(prob)
  ind = np.argsort(prob)
  indd = ind[::-1] # descendent sort
  cum = np.cumsum(prob[indd])
  pp = np.argmax(cum>cumpcut)
  inB = labelsAwB[indd[:pp+1]]
  
  return (inB, indd, cum)

In [25]:
# separated for otimization
def bayesinf_nocut(genB,p):

  prob = p[genB,:]
  prob = prob/np.sum(prob)
  ind = np.argsort(prob)
  indd = ind[::-1] # descendent sort
  cum = np.cumsum(prob[indd])
  
  return indd, cum

def bayesinf_justcut(cum, cumpcut, indd, labelsAwB):

  pp = np.argmax(cum>cumpcut)
  inB = labelsAwB[indd[:pp+1]]
  
  return inB

def simple_cut(score, cumpcut, labels):

  pp = np.argmax(score>cumpcut)
  inB = labels[pp]
  
  return inB

# Main

#### Performance measures

In [26]:
def spf1(refP,refN,predP,labels,remove_labels=[]):
  if len(remove_labels)>0:
    refP = np.setdiff1d(refP,remove_labels)
    predP = np.setdiff1d(predP,remove_labels)

  predN = np.setdiff1d(labels,predP)
  TP = len(np.intersect1d(refP,predP))
  TN = len(np.intersect1d(refN,predN))
  FP = len(np.intersect1d(predP,refN))
  FN = len(np.intersect1d(predN,refP))
  sens = TP/(TP+FN)
  if TP+FP>0: prec = TP/(TP+FP)
  else: prec = 0
  if sens+prec>0: F1 = 2*sens*prec/(sens+prec)
  else: F1 = 0
  return (TP, TN, FP, FN, sens, prec, F1)

### Parameters

In [27]:
lmbd = 0.001      #@param {type:"slider", min:0, max:200000, step:0.05}
dim = 80        #@param {type:"slider", min:10, max:2000, step:10}
pexp = 6        #@param {type:"slider", min:1, max:15, step:1}
max_iter = 500  #@param {type:"slider", min:1, max:5000, step:100}
tol = 0.001     #@param {type:"slider", min:0, max:1, step:0.001}

gamma = 1.0     #param {type:"slider", min:0, max:1, step:0.05}


In [28]:
Ngenes = Dg.shape[0]
Nlabels = len(labels)

d_expr = De.to_numpy()
d_GO = Dg.to_numpy()    # TO-DO clear copies

indx_to_predictLK = np.where(De.index.isin(LK))
indx_to_predictNK = np.where(De.index.isin(NK))
indx_to_predict = np.concatenate((indx_to_predictLK[0], indx_to_predictNK[0]))
indx_unknown = indx_to_predictNK[0]

d_GO_rec = nnmfgo(d_expr, d_GO, indx_unknown, lmbd, dim, max_iter)
d_GO_final = d_GO_rec

In [29]:
#d_GO_final = Dg.fillna(2.0).to_numpy() # <<< oracle

(weights, countAB) = bayesprobs(d_GO_final,indx_to_predict,labels,ancT_1,pexp)
p = np.multiply(weights,countAB)

In [30]:
Ngenes_to_predict = len(indx_to_predict)
Max_labs_x_gen = len(ancT0.iloc[0,:].dropna().values)

indord = np.zeros((Ngenes_to_predict, Nlabels), dtype=np.uint32)
cum =  np.zeros((Ngenes_to_predict, Nlabels))
refP = np.zeros((Ngenes_to_predict, Max_labs_x_gen), dtype=np.uint32)
refN = np.setdiff1d(labels,refP)
remove_labels = np.zeros_like(refP)

for k, igenloo in enumerate(indx_to_predict):
  indord[k,:], cum[k,:] = bayesinf_nocut(k, p)
  refP[k,:] = ancT0.iloc[igenloo,:].dropna().values
  remove_labels[k,:] = ancT_1.iloc[igenloo,:].dropna().values

Nth = 100
cum_ths = np.linspace(0.0+1.0/Nth, 1.0, num=100)
F1th = np.zeros((Nth,))
F1 = np.zeros((len(cum_ths),Ngenes_to_predict))
for ith, ccut in enumerate(cum_ths):
  for k in range(Ngenes_to_predict):
    labelsB_genB = bayesinf_justcut(cum[k,:], ccut, indord[k,:], labels)
    _, _, _, _, _, _, F1[ith,k] = spf1(refP[k,:],refN,labelsB_genB,labels)
    F1th[ith] += F1[ith,k]
  F1th[ith] = F1th[ith]/len(indx_to_predict)

print(F1th)
print('F1max:', max(F1th))
print('pcut: ', cum_ths[np.argmax(F1th)])

print(np.mean(F1[np.argmax(F1th),:]))
print(np.std(F1[np.argmax(F1th),:]))

[0.1118086  0.1118086  0.1118086  0.1118086  0.1118086  0.1118086
 0.1118086  0.1118086  0.1118086  0.1118086  0.1118086  0.1115723
 0.12208287 0.1258178  0.13373    0.14294061 0.14575123 0.14688891
 0.15320068 0.15298568 0.16018782 0.15988396 0.16279916 0.16737895
 0.17375758 0.18043667 0.18453195 0.18553716 0.18892728 0.19398283
 0.19771851 0.20605468 0.20895048 0.20883668 0.21603988 0.22004259
 0.21928377 0.22075142 0.22209638 0.22139968 0.22532235 0.23089486
 0.2337745  0.23389761 0.2322342  0.23372058 0.23741988 0.23799034
 0.24131618 0.24695773 0.24728802 0.25042746 0.25831375 0.26069552
 0.26386018 0.26802789 0.27098703 0.27776639 0.27888192 0.28419468
 0.2875345  0.28976697 0.28963179 0.29201037 0.29395208 0.29599336
 0.29426841 0.29202839 0.29418354 0.29309312 0.29353691 0.29594852
 0.29776604 0.30386379 0.30367943 0.30404533 0.30407094 0.30692645
 0.305985   0.30730099 0.30443264 0.3019172  0.30121211 0.30104054
 0.30011173 0.29807391 0.29406665 0.29000926 0.28568313 0.278948