In [None]:
import pandas as pd 
import numpy as np
import glob
from tqdm import tqdm 
import numba as nb
from numba import njit, jit

import editdistance 
import torch 
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from seq_models import SeqCNN
from seqgen import editDist
from seqio import FastaFile
from config import config 


import matplotlib.pyplot as plt
import plotly.express as ex 

import warnings
warnings.filterwarnings('ignore')

In [None]:
def cosine_dist(a, b):
    """Return the cosine distance ``1 - cos(a, b)`` between two vectors.

    The dot product is taken with ``np.matmul`` and divided by the Euclidean
    norm of ``a`` and then of ``b``. Zero-norm inputs are not guarded against.

    NOTE: a numba-compiled function with the same name is defined further
    down in this notebook and replaces this one once that cell has run.
    """
    dot = np.matmul(a, b)
    cosine = dot / np.linalg.norm(a) / np.linalg.norm(b)
    return 1 - cosine


def cosine_sim(a, b):
    """Return the cosine similarity of two vectors.

    Computed as the ``np.matmul`` product divided by the Euclidean norm of
    ``a`` and then of ``b``. Zero-norm inputs are not guarded against.
    """
    dot = np.matmul(a, b)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    return dot / norm_a / norm_b


def load_pretrained_net(groups, kernel, num_layers, stride=2, in_channels=4, channels=1, i=0):
    """Build a ``SeqCNN`` and load the weights of a matching saved checkpoint.

    Checkpoints are searched for in ``config['networks_dir']`` using a
    filename pattern that encodes the hyper-parameters below.

    Args:
        groups, kernel, num_layers, stride, in_channels, channels: forwarded
            to the ``SeqCNN`` constructor and used to build the filename pattern.
        i: index into the sorted list of matching checkpoint paths.

    Returns:
        The ``SeqCNN`` instance with the checkpoint's state dict loaded
        (weights are on the CPU; move the net to another device afterwards).

    Raises:
        IndexError: if ``i`` does not select a matching checkpoint, in
            particular when no file matches the pattern.
    """
    pattern = f"{config['networks_dir']}/*_in_channels{in_channels}_num_layers{num_layers}_channels{channels}_kernel{kernel}_stride{stride}_groups{groups}_*"
    # glob returns matches in arbitrary order; sort so that `i` is stable across runs.
    paths = sorted(glob.glob(pattern))
    print(paths)
    try:
        path = paths[i]
    except IndexError:
        raise IndexError(
            f"no checkpoint at index {i}: {len(paths)} file(s) match {pattern!r}"
        ) from None

    net = SeqCNN(in_channels=in_channels, groups=groups, kernel=kernel, num_layers=num_layers, stride=stride, channels=channels)
    # map_location='cpu' lets a checkpoint that was saved from a GPU be loaded
    # on a machine without CUDA.
    net.load_state_dict(torch.load(path, map_location='cpu'))
    return net


# def load_best_net(layers):
#     best_params = {
#         8: {}
#     }


def load_ed_viral():
    """Read the pairwise edit-distance table stored at ``config['viral_ed']``.

    The CSV is read with the column names ``seq1``, ``seq2`` and ``ed``, and
    every ``'.gz'`` substring is removed from the two name columns.

    Returns:
        (df, seq_names): the table and a de-duplicated list (in no particular
        order) of every name that occurs in ``seq1`` or ``seq2``.
    """
    df = pd.read_csv(config['viral_ed'], names=['seq1', 'seq2', 'ed'])
    for column in ('seq1', 'seq2'):
        df[column] = df[column].apply(lambda name: name.replace('.gz', ''))

    all_names = [*df['seq1'].unique(), *df['seq2'].unique()]
    return df, list(set(all_names))


def save_viral_data():
    """Collect every sequence of the viral FASTA files into ``viral.npz``.

    All ``*.fna`` files under ``config['viral_dir']`` are parsed with
    ``FastaFile``. For each contained sequence the file name, the sequence,
    its id and its length are stored under the keys ``name``, ``seq``, ``id``
    and ``len`` respectively.
    """
    paths = glob.glob(f"{config['viral_dir']}/*.fna")
    fastafiles = [FastaFile(p) for p in tqdm(paths, total=len(paths))]

    names, ids, seqs, lens = [], [], [], []
    for ff in fastafiles:
        for s in ff.seqs:
            names.append(ff.name)
            seqs.append(s.seq)
            ids.append(s.id)
            lens.append(len(s.seq))

    # Sequences generally differ in length. Recent NumPy versions refuse to
    # build an array from such a ragged list implicitly (ValueError), so fall
    # back to an explicit 1-D object array holding one entry per sequence.
    try:
        seq_array = np.asarray(seqs)
    except ValueError:
        seq_array = np.empty(len(seqs), dtype=object)
        for k, seq in enumerate(seqs):
            seq_array[k] = seq

    np.savez('viral', name=names, seq=seq_array, id=ids, len=lens)
    # The former trailing `np.isin(names, seq_names)` was removed: `seq_names`
    # is not defined in this function and the result was never used, so the
    # call could only raise NameError after the file had been written.
    

class ViralDataset(Dataset):
    """Fixed-length windows over the sequences stored in ``viral.npz``.

    Every item is one window of ``L`` symbols, one-hot encoded over an
    alphabet of size 4, together with the index of its source sequence and
    the window's start offset.

    Args:
        L: window length.
        stride: step between consecutive window starts within a sequence.
        samples: if positive, only sequences whose stored length is strictly
            smaller than this value are kept; 0 keeps all sequences.
    """

    def __init__(self, L, stride, samples=0):
        super().__init__()
        data = np.load('viral.npz', allow_pickle=True)
        self.seqs = data['seq']
        lens = data['len']
        if samples > 0:
            self.seqs = self.seqs[lens < samples]
        self.L = L
        self.alph = 4
        self.stride = stride
        # Parallel lists: source-sequence index and start offset of each window.
        self.sid, self.pos = [], []
        for si, s in enumerate(self.seqs):
            # The last valid start is len(s) - L, hence the "+ 1". The previous
            # bound stopped one short, so the final aligned window was dropped
            # and a sequence of exactly L symbols produced no window at all.
            for start in range(0, len(s) - L + 1, stride):
                self.pos.append(start)
                self.sid.append(si)

    def get_seq(self, i):
        """Return the raw (not one-hot encoded) symbols of window ``i``."""
        si, idx = self.sid[i], self.pos[i]
        return self.seqs[si][idx:idx + self.L]

    def __len__(self):
        """Total number of windows over all kept sequences."""
        return len(self.sid)

    def __getitem__(self, i):
        """Return ``(X, (sid, pos))`` for window ``i``.

        For a 1-D integer window, ``X`` is a float tensor of shape
        ``(alph, L)`` holding the one-hot encoding; ``sid`` and ``pos`` are
        the source-sequence index and the start offset.
        """
        s = self.get_seq(i)
        X = torch.from_numpy(s).type(torch.int64)
        X = F.one_hot(X, num_classes=self.alph)
        # one_hot yields (L, alph); swap to channels-first (alph, L).
        X.transpose_(0, 1)
        X = X.float()
        return X, (self.sid[i], self.pos[i])

    
def embed_seqs(net, dataset, L, stride):
    """Embed every item of ``dataset`` with ``net``.

    The dataset is read in order (no shuffling) in batches of 256 and run
    through the network on the notebook-level ``device``. Inference is done
    in evaluation mode and without autograd bookkeeping; the network's
    previous train/eval mode is restored afterwards.

    Args:
        net: the embedding network (a ``torch.nn.Module``).
        dataset: dataset whose items are ``(X, (sid, pos))`` tuples.
        L, stride: accepted for call compatibility; not used here.

    Returns:
        (embeddings, sids, pos): the concatenated network outputs with
        singleton dimensions squeezed out, and the matching sequence indices
        and window offsets, all as NumPy arrays.
    """
    loader = DataLoader(dataset=dataset, batch_size=256, shuffle=False)
    net = net.to(device)

    embeddings, sids, pos = [], [], []
    was_training = net.training
    # eval() keeps layers such as dropout / batch-norm from behaving
    # batch-dependently during inference.
    net.eval()
    try:
        # no_grad avoids building an autograd graph for every batch.
        with torch.no_grad():
            for data, (si, idx) in tqdm(loader, total=len(loader)):
                embed = net(data.to(device))
                embeddings.append(embed.cpu().numpy())
                sids.append(np.asarray(si))
                pos.append(np.asarray(idx))
    finally:
        net.train(was_training)

    embeddings = np.concatenate(embeddings).squeeze()
    sids = np.concatenate(sids)
    pos = np.concatenate(pos)
    return embeddings, sids, pos


def nearest_neighbors(embeddings, ensembles=3, leaf=10, dim_expand=2):
    """Build approximate-neighbour lists from the sign bits of one random projection.

    The embeddings are projected with a single Gaussian matrix to
    ``int(log2(N) * dim_expand)`` dimensions and reduced to their sign bits.
    For every ensemble the bit rows are shuffled again (the shuffles compound
    from one ensemble to the next) and the points are ordered with
    ``np.lexsort``; points at nearby positions of such an ordering are
    recorded as neighbours.

    Args:
        embeddings: 2-D array whose shape unpacks to ``(N, din)``.
        ensembles: number of sort orders to generate.
        leaf: number of offsets (``0 .. leaf-1``, wrapping modulo N) paired
            with each position of a sort order.
        dim_expand: multiplier applied to ``log2(N)`` to get the bit count.

    Returns:
        neighbors: int64 array of shape ``(N, leaf * ensembles)``.
        indices: int32 array of shape ``(ensembles, N)`` with the sort orders.
    """
    N, din = embeddings.shape
    dout = int(np.log2(N)*dim_expand)

    # project by Gaussian matrix & take the sign bit
    P = np.random.randn(din,dout)
    embed2 = np.matmul(embeddings, P)
    embed2 = (embed2>0).astype(np.int32)
    # lexsort expects one key per row, so bits become rows and points columns
    embed2 = embed2.transpose()

    # lexical sort after random column permutation
    indices = np.zeros((N,ensembles),dtype=np.int32)
    for j in range(ensembles):
        embed2 = embed2[np.random.permutation(dout),:]
        indices[:,j] = np.lexsort(embed2)
    indices = indices.transpose()

    # Column block [e*leaf, (e+1)*leaf) of `neighbors` belongs to ensemble e.
    # NOTE(review): offset j starts at 0, which pairs every point with itself,
    # and both rows u and v are written at the same column, so a later
    # iteration can overwrite an earlier entry - confirm this is intended.
    neighbors = np.zeros((N,leaf * ensembles), np.int64)
    for e in range(ensembles):
        print(f"ensemble {e}")
        for i in tqdm(range(N),total=N):
            for j in range(leaf):
                u, v = indices[e,i], indices[e,(i+j)%N]
                idx = j + e*leaf
                neighbors[u, idx] = v
                neighbors[v, idx] = u
                
    return neighbors, indices 


# new projection for every ensemble 
def nearest_neighbors2(embeddings, ensembles=3, leaf=10, dim_expand=1):
    """Build approximate-neighbour lists, drawing a new projection per ensemble.

    For every ensemble a fresh Gaussian matrix projects the embeddings to
    ``int(log2(N) * dim_expand)`` dimensions; the sign bits are ordered with
    ``np.lexsort`` and points at nearby positions of that ordering are
    recorded as neighbours.

    Args:
        embeddings: 2-D array whose shape unpacks to ``(N, din)``.
        ensembles: number of projections / sort orders to generate.
        leaf: number of offsets (``0 .. leaf-1``, wrapping modulo N) paired
            with each position of a sort order.
        dim_expand: multiplier applied to ``log2(N)`` to get the bit count.

    Returns:
        neighbors: int64 array of shape ``(N, leaf * ensembles)``.
        indices: int32 array of shape ``(ensembles, N)`` with the sort orders.
    """
    N, din = embeddings.shape
    dout = int(np.log2(N)*dim_expand)

    # lexical sort of the sign bits; each ensemble uses its own random projection
    indices = np.zeros((N,ensembles),dtype=np.int32)
    print('projection ')
    for j in tqdm(range(ensembles),total=ensembles):
        P = np.random.randn(din,dout)
        embed2 = np.matmul(embeddings, P)
        embed2 = (embed2>0)
        # lexsort expects one key per row, so bits become rows and points columns
        embed2 = embed2.transpose()
        indices[:,j] = np.lexsort(embed2)
    indices = indices.transpose()

    # Column block [e*leaf, (e+1)*leaf) of `neighbors` belongs to ensemble e.
    # NOTE(review): offset j starts at 0, which pairs every point with itself,
    # and both rows u and v are written at the same column, so a later
    # iteration can overwrite an earlier entry - confirm this is intended.
    neighbors = np.zeros((N,leaf * ensembles), np.int64)
    for e in range(ensembles):
        print(f"ensemble {e}")
        for i in tqdm(range(N),total=N):
            for j in range(leaf):
                u, v = indices[e,i], indices[e,(i+j)%N]
                idx = j + e*leaf
                neighbors[u, idx] = v
                neighbors[v, idx] = u
                
    return neighbors, indices 


# new projection for every ensemble 
def nearest_neighbors3(embeddings, ensembles=3, leaf=10, dim_expand=1):
    """Return one sign-bit sort order of the points per ensemble.

    For every ensemble a fresh Gaussian matrix projects the embeddings to
    ``int(log2(N) * dim_expand)`` dimensions and the resulting sign bits are
    ordered with ``np.lexsort``.

    Args:
        embeddings: 2-D array whose shape unpacks to ``(N, din)``.
        ensembles: number of projections / sort orders to generate.
        leaf: accepted for signature parity with the sibling functions; unused.
        dim_expand: multiplier applied to ``log2(N)`` to get the bit count.

    Returns:
        int64 array of shape ``(ensembles, N)``; row ``e`` is the ordering
        obtained from the e-th projection.
    """
    num_points, in_dim = embeddings.shape
    num_bits = int(np.log2(num_points) * dim_expand)

    orderings = np.zeros((num_points, ensembles), dtype=np.int64)
    print('projection ')
    for e in tqdm(range(ensembles), total=ensembles):
        projection = np.random.randn(in_dim, num_bits)
        # Bits as rows: lexsort treats each row as one sort key.
        sign_bits = (np.matmul(embeddings, projection) > 0).T
        orderings[:, e] = np.lexsort(sign_bits)

    return orderings.T


def nearest_neighbors4(embeddings, ensembles=3, leaf=10, dim_expand=2):
    """Return sign-bit sort orders together with the final bit codes.

    A single Gaussian matrix projects the embeddings to
    ``int(log2(N) * dim_expand)`` dimensions and only the sign bits are kept.
    For every ensemble the bit rows are shuffled again (shuffles compound
    across ensembles) and the points are ordered with ``np.lexsort``.

    Args:
        embeddings: 2-D array whose shape unpacks to ``(N, din)``.
        ensembles: number of sort orders to generate.
        leaf: accepted for signature parity with the sibling functions; unused.
        dim_expand: multiplier applied to ``log2(N)`` to get the bit count.

    Returns:
        (indices, codes): an int32 array of shape ``(ensembles, N)`` with the
        sort orders, and the int32 bit codes of shape ``(N, dout)`` in the
        row order left by the last shuffle.
    """
    num_points, in_dim = embeddings.shape
    num_bits = int(np.log2(num_points) * dim_expand)

    # One projection shared by all ensembles; keep only the sign bit.
    projection = np.random.randn(in_dim, num_bits)
    bits = (np.matmul(embeddings, projection) > 0).astype(np.int32).T

    orderings = np.zeros((num_points, ensembles), dtype=np.int32)
    for e in range(ensembles):
        shuffle = np.random.permutation(num_bits)
        bits = bits[shuffle, :]
        orderings[:, e] = np.lexsort(bits)

    return orderings.T, bits.T

In [None]:
def get_pairs(indices, leaf):
    """Pair every entry of each ordering with the entries 1..leaf places before it.

    For each shift ``1 .. leaf`` (outer loop) and each row of ``indices``
    (inner loop), the row is paired element-wise with ``np.roll(row, shift)``,
    so positions wrap around at the ends.

    Returns:
        Array of shape ``(leaf * len(indices) * N, 2)`` with one pair per row.
    """
    pair_blocks = []
    for shift in range(1, leaf + 1):
        for order in indices:
            shifted = np.roll(order, shift)
            pair_blocks.append(np.column_stack((order, shifted)))
    return np.concatenate(pair_blocks)

@njit
def get_rank(array):
    """Return the rank (0 = smallest) of every element of a 1-D array."""
    order = array.argsort()
    ranks = np.empty_like(order)
    # order[k] is the index of the k-th smallest element, so that index gets rank k.
    for k in range(len(array)):
        ranks[order[k]] = k
    return ranks

@njit
def get_ranks(matrix):
    """Rank the entries of every row of ``matrix`` independently.

    Args:
        matrix: 2-D array; each row is ranked on its own with ``get_rank``.

    Returns:
        int64 array of the same shape as ``matrix`` where entry ``[j, k]`` is
        the rank (0 = smallest) of ``matrix[j, k]`` within row ``j``.
    """
    # Use the matrix's own row count; the old loop read the notebook global
    # `ensembles`, which need not exist or match the input.
    num_rows = matrix.shape[0]
    ranks = np.empty(matrix.shape, dtype=np.int64)
    for j in range(num_rows):
        ranks[j] = get_rank(matrix[j])
    # The result was previously computed but never returned.
    return ranks

            
@njit
def pbar(i, total, l=100):
    """Print coarse progress as an integer percentage.

    Args:
        i: current iteration index.
        total: total number of iterations.
        l: approximate number of progress lines to print over the whole run.
    """
    if total <= 0:
        return
    # int(total / l) is 0 whenever total < l, which made the modulo below
    # divide by zero; print on every iteration in that case instead.
    step = max(int(total / l), 1)
    if i % step == 0:
        print(int(100 * i / total))


@njit
def get_dist_approx(X, emb):
    """Score every index pair in ``X`` from the rows of ``emb``.

    For a pair ``(i, j)`` the value is ``1 - sum(emb[i] * emb[j]) / D`` where
    ``D`` is the number of columns of ``emb``. Progress is reported via ``pbar``.

    Args:
        X: 2-D array with one ``(i, j)`` index pair per row.
        emb: 2-D array indexed by the entries of ``X``.

    Returns:
        1-D float array with one value per row of ``X``.
    """
    num_pairs = X.shape[0]
    width = emb.shape[1]
    dist = np.zeros(num_pairs)
    for r in range(num_pairs):
        pbar(r, num_pairs)
        first, second = X[r]
        overlap = np.sum(emb[first, :] * emb[second, :])
        dist[r] = 1 - overlap / width
    return dist


@njit
def cosine_dist(a, b):
    """Return the cosine distance ``1 - cos(a, b)`` of two vectors.

    Both vectors are scaled to unit Euclidean norm before the element-wise
    product is summed. Zero-norm inputs are not guarded against.

    NOTE: this re-uses the name of the plain NumPy ``cosine_dist`` defined in
    an earlier cell and replaces it once this cell has run.
    """
    unit_a = a / np.linalg.norm(a)
    unit_b = b / np.linalg.norm(b)
    similarity = np.sum(unit_a * unit_b)
    return 1 - similarity


@njit
def get_dist_exact(X, emb, l=100):
    """Cosine distance between the embeddings of every index pair in ``X``.

    Args:
        X: 2-D array with one ``(i, j)`` index pair per row.
        emb: 2-D array of embeddings indexed by the entries of ``X``.
        l: approximate number of progress lines printed by ``pbar``.

    Returns:
        1-D float64 array; entry ``r`` is ``cosine_dist(emb[i], emb[j])`` for
        the pair in row ``r`` (0 for identical directions).
    """
    N = X.shape[0]
    d = np.zeros(N, dtype=np.float64)
    for r in range(N):
        pbar(r, N, l)
        # The old `1 - cosine_dist(...)` produced the cosine *similarity*;
        # sorting that ascending ranked the least similar pairs first even
        # though the result is used and labelled as a distance.
        d[r] = cosine_dist(emb[X[r, 0]], emb[X[r, 1]])
    return d

# Embed sequences using pre-trained model 

In [None]:


# Hyper-parameters; they also select which pre-trained checkpoint is loaded.
num_layers = 8
kernel=3
groups=4
stride_ratio = 1./4
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Window length and the step between consecutive windows.
L = 2**num_layers
stride = int(L*stride_ratio)
dataset = ViralDataset(samples=2000, L=L, stride=stride)
net = load_pretrained_net(groups=groups, kernel=kernel, num_layers=num_layers,)
embeddings, sids, pos = embed_seqs(net, dataset, L, stride)
# Cache the result so later sessions can reload it instead of re-embedding.
np.savez('embed',embeddings=embeddings, sids=sids, pos=pos)
embeddings.shape, sids.shape, pos.shape

In [None]:
# Reload the cached embeddings written by the embedding cell.
data = np.load('embed.npz')
embeddings, sids, pos = data['embeddings'], data['sids'], data['pos']

# 'indices.npz' is only written by the nearest-neighbour cell further down, so
# on a fresh run it may not exist yet. Fall back to None in that case and let
# that cell compute the orderings instead of failing here.
try:
    indices = np.load('indices.npz',allow_pickle=True)['indices']
except FileNotFoundError:
    indices = None

embeddings.shape, sids.shape, pos.shape, (None if indices is None else indices.shape)

# Nearest neighbors in embedding space 

In [None]:
%%time 
            
ensembles = 10
leaf = 100
dim_expand = 5

indices = nearest_neighbors3(embeddings, ensembles=ensembles, leaf=leaf, dim_expand=dim_expand)
np.savez('indices',indices=indices)

indices.shape

In [None]:
%%time

X = get_pairs(indices=indices, leaf=leaf)

ED = get_dist_exact(X, embeddings, 10)

In [None]:
# Positions of the candidate pairs ordered by ascending ED value.
sorted_idx = ED.argsort()

In [None]:
def v2s(x):
    """Turn a sequence of small integers into letters (0 -> 'a', 1 -> 'b', ...)."""
    return "".join([chr(a + ord('a')) for a in x])


num_samples = 1000
x = np.empty(num_samples)
y = np.empty(num_samples)
meta = [None] * num_samples

for ri in range(num_samples):
    # r is drawn from [0, num_samples), so only the num_samples lowest-ED
    # pairs can be picked, and the same pair may be picked more than once.
    r = np.random.randint(num_samples)
    si = sorted_idx[r]
    i, j = X[si, :]
    s1 = dataset.get_seq(i)
    s2 = dataset.get_seq(j)
    # Edit distance of the two raw windows, normalised by the window length.
    d = editdistance.eval(s1, s2) / L
    s1 = v2s(s1)
    s2 = v2s(s2)
    meta[ri] = f"{s1[:10]},\n {s2[:10]}"
    x[ri] = ED[si]
    y[ri] = d

In [None]:
# One row per sampled pair: its ED value, its normalised edit distance and a
# short text preview of both windows for the hover label.
df = pd.DataFrame(
    data={
        'embed dist': x,
        'edit dist': y,
        'meta': meta,
    }
)
ex.scatter(data_frame=df, x='edit dist', y='embed dist', hover_name='meta')

In [None]:
# Plot every 1000th value of the sorted ED array to show its overall profile.
ex.line(data_frame=ED[sorted_idx[::1000]])

## TSNE plot 

In [None]:
from sklearn.manifold import TSNE

# 3-D t-SNE projection of the window embeddings.
embed_low = TSNE(n_components=3).fit_transform(embeddings)

# `ids` was never defined anywhere in this notebook, so the label lookup
# below raised NameError on a fresh kernel. Rebuild it from viral.npz.
# The length filter must match the `samples=2000` passed to ViralDataset in
# the embedding cell, because `sids` index the filtered sequence list.
viral = np.load('viral.npz', allow_pickle=True)
ids = viral['id'][viral['len'] < 2000]
assert sids.max() < len(ids), "sids do not match the filtered id list"

# Colour each window by the first space-separated token of its sequence id.
labels = [ids[sid].split(' ')[0] for sid in sids]
ex.scatter_3d(embed_low, x=0, y=1, z=2, color=labels)