In [3]:
#Dependencies
import numpy as np
import cPickle as pickle
import glob
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons
#Useful functions
def imshow3d(Im, axis=0, **kwargs):
    """
    Display a 3d ndarray with a slider to move along dimension `axis`.

    Im is converted with np.array; axis selects the dimension the slider
    steps through (default 0).
    Extra keyword arguments are passed to imshow
    """
    im = np.array(Im)

    def plane(ind):
        # np.take picks one index along `axis` without building a list of
        # slices (indexing an array with a *list* of slices is rejected by
        # current NumPy); squeeze drops any remaining length-1 dimensions.
        return np.take(im, ind, axis=axis).squeeze()

    # generate figure
    f, ax = plt.subplots()
    f.subplots_adjust(left=0.25, bottom=0.25)
    # display first image, colour limits fixed to the range of the whole volume
    l = ax.imshow(plane(0), **kwargs)
    l.set_clim(vmin=np.min(im), vmax=np.max(im))
    # define slider
    axcolor = 'lightgoldenrodyellow'
    slider_ax = f.add_axes([0.25, 0.1, 0.65, 0.03])
    # The `axisbg` keyword is gone from newer Matplotlib; set the background
    # through whichever setter this Matplotlib version provides.
    set_background = (getattr(slider_ax, 'set_facecolor', None)
                      or getattr(slider_ax, 'set_axis_bgcolor'))
    set_background(axcolor)
    slider = Slider(slider_ax, 'Axis %i index' % axis, 0, im.shape[axis] - 1,
                    valinit=0, valfmt='%i')

    def update(val):
        # redraw with the plane under the (truncated) slider position
        l.set_data(plane(int(slider.val)))
        f.canvas.draw()
    slider.on_changed(update)
    # Keep the widget referenced by the figure so it is not garbage collected
    # (and made unresponsive) when plt.show() returns immediately.
    f._imshow3d_slider = slider
    plt.show()

def chromosomes(nuc_dia=10000,pixel_sz=100,plt_val=False):
    """
    nuc_dia is the nuclear diameter in nm
    pixel_sz is in nm
    plt_val, when True, also displays the labelled territories with imshow3d
    This assumes 46 chromosomes
    
    Return list of chromosomes and pixels in their territory:
    one list per chromosome holding the [x,y,z] grid points (unit-sphere
    coordinates) that fall inside the unit sphere and are closer to that
    chromosome's center than to any other center.
    """
    #coordinates for 46 sphere centers
    #(equal size spheres in a unit sphere)
    #see:https://oeis.org/A084827/a084827.txt
    centers=[[-0.127724638717686,0.029283782782012,-0.763670872459570], 
            [0.302116854275886,0.146601789724809,-0.698281876003332], 
            [0.050116071438789,-0.375084565347080,-0.676139240788969], 
            [0.387404648096449,-0.300279722464142,-0.600095035492607],
            [-0.221565702064757,0.438003368581342,-0.599521487098418],
            [-0.536838502467010,0.121012629438513,-0.545458207564384],
            [0.470578557122151,-0.324839673302964,-0.522876020618661],
            [0.206821475639773,0.544767478949537,-0.510703040725137],
            [0.647737208453552,0.075787428022586,-0.418398254359731],
            [0.209291510617636,-0.653452063989750,-0.359946924370349],
            [-0.240428762326608,-0.655246890184877,-0.336466711372591],
            [0.027563278735129,0.169874066797150,-0.337139524778479],
            [-0.531122333361574,0.491550397468556,-0.276860250786947],
            [-0.125040038594464,0.718782537235944,-0.260923317520113],
            [-0.028222635427186,-0.267579430698296,-0.245896798982907],
            [0.559897837805783,0.479367416697336,-0.238925962888257],
            [-0.609344934400770,-0.421155893776354,-0.227356083644822],
            [-0.755792906627536,0.000918343779410,-0.170705973387576],
            [0.709453517788630,-0.276107684781292,-0.144237918782831],
            [0.338406350902039,-0.029318746498438,-0.079260341210368],
            [0.256184770042010,0.730938689442354,-0.021501641508632],
            [-0.268046158037773,0.223179830668424,-0.001615424109930],
            [0.463839024087979,-0.620577043697123,0.010090454994701],
            [0.761425580114896,0.142996856131315,0.012137124700828],
            [0.041055031342583,-0.772687639260906,0.040405708106847],
            [-0.343201070932800,-0.214763803705687,0.071596445689072],
            [-0.392969757022585,-0.662069840802751,0.087193008193199],
            [-0.377886422912343,0.667723934061050,0.108217022567140],
            [-0.686352373667351,0.339757482368351,0.117684310970756],
            [0.150619047600183,0.321066162828993,0.132327016008240],
            [0.137964450619487,-0.350718167453077,0.164313718413543],
            [0.559387984377712,0.492787670746059,0.211210130456054],
            [-0.717576062734593,-0.078536494382680,0.281568709115817],
            [0.643403410008865,-0.310581345960640,0.299892559603968],
            [0.002276767510746,0.692083481917933,0.348395549284496],
            [-0.069193117297735,-0.000826838519097,0.357871631431749],
            [-0.074584688024342,-0.626168415760149,0.450238341469810],
            [0.622296753862575,0.114447785021264,0.447227819362128],
            [-0.471682318226388,-0.413806749821993,0.454581223971127],
            [-0.434951569989064,0.423001400164857,0.481924550044267],
            [0.305007962363991,-0.417373667885278,0.577177346197253],
            [0.295340191120549,0.432541638100190,0.571004577504622],
            [-0.446844519125231,-0.001070128504388,0.633003282191707],
            [-0.094303907267779,-0.267030401770297,0.721225250748454],
            [0.281485705865138,0.008444506916010,0.721844036069634],
            [-0.091709170433872,0.260484226789782,0.723948700091940]]

    centers = np.array(centers)
    
    # Explicit floor division + int(): a plain "/" only yields an int for
    # Python 2 integers, and linspace/zeros below need an integer size.
    arr_size = int(nuc_dia//pixel_sz)
    x_ = np.linspace(-1,1,arr_size)
    chrters = [[] for i in range(len(centers))]

    # Work one x-plane at a time: the (y,z) grid is flattened with y as the
    # slow index and z as the fast one, so points reach each chromosome list
    # in the same x -> y -> z order a triple nested loop would produce.
    yy,zz = np.meshgrid(x_,x_,indexing='ij')
    yy = yy.ravel()
    zz = zz.ravel()
    for x in x_:
        #test if in sphere
        inside = x*x+yy*yy+zz*zz<=1
        if not inside.any():
            continue
        pts = np.empty((int(inside.sum()),3))
        pts[:,0] = x
        pts[:,1] = yy[inside]
        pts[:,2] = zz[inside]
        #squared distance of every point in the plane to every center
        d2 = np.sum((centers[np.newaxis,:,:]-pts[:,np.newaxis,:])**2,axis=-1)
        closest = np.argmin(d2,axis=1)#index of the closest center per point
        for chr_index in np.unique(closest):
            chrters[chr_index].extend(pts[closest==chr_index].tolist())
    if plt_val:
        # label volume: voxel value i+1 marks the territory of chromosome i
        im = np.zeros([arr_size]*3)
        for i,chr_ in enumerate(chrters):
            for x,y,z in (np.array(chr_)+1)*(arr_size-1)/2:
                im[int(np.round(x)),int(np.round(y)),int(np.round(z))]=i+1
        imshow3d(im,interpolation='nearest')
    return chrters

def TAD_blur(xyzPos,pix_sz=100,nuc_dia=10000):
    """
    Jitter a position with Gaussian noise of half a pixel.
    xyzPos is a length-3 position in unit-radius coordinates
    pix_sz and nuc_dia are in nm; the noise s.d. is (pix_sz/2)/(nuc_dia/2)
    Returns xyzPos plus the 3-component perturbation
    """
    perturb=np.random.normal(0,pix_sz/2./(nuc_dia/2.),3)
    return perturb+xyzPos

def TAD_generator(xyzChr,noTADs=100,udist=-0.44276236166846844,sigmadist=0.57416477624326434,nuc_dia=10000,pix_sz=100):
    """
    xyzChr is a list of positions belonging to a chromosome territory
    noTADs is the number of TAD positions to draw
    udist, sigmadist are the mean and s.d. of the Gaussian weight applied to
    the log distance (log um) between consecutive TADs
    nuc_dia, pix_sz are in nm and are forwarded to TAD_blur
    Returns an array of dimensions noTADs x 3 (empty when noTADs<=0)
    Raises ValueError when xyzChr is empty
    """
    xyzChr_=np.array(xyzChr)
    if len(xyzChr_)==0:
        raise ValueError("xyzChr must contain at least one position")
    if noTADs<=0:
        return np.zeros((0,3))
    tads=[]
    # low bound 0 so that the first territory position can be drawn too
    # (and a one-position territory does not make randint fail)
    first=xyzChr_[np.random.randint(0,len(xyzChr_))] #randomly choose location of first TAD
    first=TAD_blur(first,pix_sz,nuc_dia)
    tads.append(first)
    for i_tad in range(noTADs-1):
        difs=xyzChr_-[tads[i_tad]]#unit radius
        dists=np.sqrt(np.sum(difs**2,axis=-1))
        # a zero distance gives log(0)=-inf, which becomes weight 0 below
        with np.errstate(divide='ignore'):
            dists=np.log(dists*nuc_dia/2000.)#unit log um
        weights = np.exp(-(dists-udist)**2/(2*sigmadist**2))
        weights = np.cumsum(weights)
        total = weights[-1]
        if total>0:
            # inverse-CDF sampling: count cumulative weights below a uniform draw
            index_pj = np.sum(np.random.rand()-weights/float(total)>0)
        else:
            # every weight vanished: fall back to a uniform choice instead of
            # selecting through a NaN comparison
            index_pj = np.random.randint(0,len(xyzChr_))
        pj=xyzChr_[index_pj]#unit radius
        pj=TAD_blur(pj,pix_sz,nuc_dia)
        tads.append(pj)
    return np.array(tads)

In [57]:
# Build the 46 chromosome territories with the defaults (10000 nm nucleus, 100 nm pixels).
chrters=chromosomes()

In [80]:

# Paint every territory into a label volume, then overlay one simulated
# TAD path on the projection of chromosome 0's territory.
arr_size=100
im = np.zeros((arr_size,)*3)
for i,chr_ in enumerate(chrters):
    # unit-sphere coordinates -> nearest voxel index
    voxels = np.round((np.array(chr_)+1)*(arr_size-1)/2)
    for x,y,z in voxels:
        im[int(x),int(y),int(z)]=i+1
x,y,z = TAD_generator(chrters[0]).T
# projection along the last axis of the voxels labelled 1 (chromosome 0)
plt.imshow((im==1).max(axis=-1))
# imshow puts the first array axis on rows, hence y is plotted horizontally
plt.plot((y+1)*(arr_size-1)/2,(x+1)*(arr_size-1)/2,'-o')

plt.show()

In [89]:
reals=100 #number of realizations
tads=100 #number of TADs per chromosome
real_matrix=[]
for i_rel in range(reals):
    # one simulated cell = one TAD array per chromosome territory
    single_cell=[TAD_generator(chrter,tads) for chrter in chrters]
    real_matrix.append(single_cell)
real_matrix=np.array(real_matrix)
# `pickle` comes from the import cell at the top of the notebook.
# The with-block closes the file even if pickling fails.
with open('simulatedTads.pkl','wb') as fid:
    pickle.dump(real_matrix,fid)

In [5]:
# The pickle was written in binary mode, so read it back with 'rb' and close the handle once loaded.
# NOTE: only unpickle files produced by this notebook; pickle can execute arbitrary code on load.
with open('simulatedTads.pkl','rb') as fid:
    real_matrix = pickle.load(fid)

In [32]:
# Rescale the coordinates by 5000 (presumably unit radius -> nm for the default nuc_dia=10000; TODO confirm).
# NOTE(review): this rebinds real_matrix, so running the cell again multiplies by 5000 a second time.
real_matrix = real_matrix*5000

In [26]:
first_cell = real_matrix[0]
# Backdrop: x-y projection of every chromosome's TADs as white dots.
for chr_ in first_cell:
    x,y,z = chr_.T
    plt.plot(x,y,'wo')
# Overlay: chromosomes 0-5 in the default colour cycle.
for i,chr_ in enumerate(first_cell):
    x,y,z = chr_.T
    if i>5:
        break
    plt.plot(x,y,'o')
plt.axis('equal')
plt.show()

In [149]:
#Encoder - construct a matrix hybes of length number of hybes x number of chromosomes 
#each containing the id of the tad in the hybe (0 means the TAD is missing from that hybe)
nreal,nchr,ntads,ndim = real_matrix.shape
# explicit lists: the buckets are popped from, which a lazy range object does not support
tad_buckets = [list(range(ntads)) for i in range(nchr)]
perc_label = 0.5
# at least one chromosome per hybe, otherwise the loop below never empties the buckets
chr_labeled = max(1,int(perc_label*nchr))
hybes=[]
while any(tad_buckets):
    lens = [len(bucket) for bucket in tad_buckets]
    # label the chromosomes that still have the most TADs left
    inds = np.argsort(lens)[::-1]
    ind_select = inds[:chr_labeled]
    ind_select = [ind for ind in ind_select if lens[ind]>0]
    
    hybe = np.zeros(nchr)
    for ind in ind_select:
        # ids are stored 1-based so that 0 can mean "not labelled in this hybe"
        hybe[ind]=tad_buckets[ind].pop(0)+1
    hybes.append(hybe)
hybes=np.array(hybes,dtype=int)

In [172]:
#Given hybes(encoder matrix) and cell(truth positions for single cell) simulate hybe data
cell = real_matrix[0]

hybes_points=[]
for hybe in hybes:
    labelled = hybe>0
    chrs_in_hybe = np.where(labelled)[0]
    # stored ids are 1-based; shift back to 0-based TAD indices
    tad_ids_in_hybe = hybe[labelled]-1
    hybe_points=[cell[chr_in_hybe][tad_in_hybe]
                 for chr_in_hybe,tad_in_hybe in zip(chrs_in_hybe,tad_ids_in_hybe)]
    hybes_points.append(hybe_points)
hybes_points = np.array(hybes_points)

In [173]:
#Decoder - given hybes_points, predict their tad id (NOTE(review): the original note was cut off after "and")
#Shape check of the simulated hybe data before decoding.
hybes_points.shape

(200L, 23L, 3L)

In [187]:
# For one reference point (point 10 of hybe id_ref), the smallest L1
# distance to any point of each hybe.
id_ref = 0
hybes_points_ref = hybes_points[id_ref]
point = hybes_points_ref[10]
min_L1_dists=[np.abs(point - hybe_point).sum(axis=-1).min()
              for hybe_point in hybes_points]


In [None]:
# The original line was an unfinished assignment (a SyntaxError); the matrix itself is built in the following cell.
possible_projections = None

In [194]:

# chromosomes labelled in each hybe
possible_chrs_hybes=[np.where(hybe>0)[0] for hybe in hybes]
# chromosomes labelled in the reference hybe
possible_chrs = np.where(hybes[id_ref]>0)[0]
# entry [i,j] is 1 when reference chromosome i is also labelled in hybe j
possible_projections = (hybes[:,possible_chrs]>0).T.astype(int)

In [195]:
# Display the reference-chromosome x hybe membership matrix as an image.
plt.imshow(possible_projections)
plt.show()

array([[1, 0, 1, ..., 0, 1, 0],
       [1, 0, 1, ..., 0, 1, 0],
       [1, 0, 1, ..., 0, 1, 0],
       ..., 
       [1, 0, 1, ..., 0, 1, 0],
       [1, 0, 1, ..., 0, 1, 0],
       [1, 0, 1, ..., 0, 1, 0]])

In [142]:
# For every realization, assign each TAD of chromosome chr_id to the
# chromosome whose TADs are closest on average, and count the misassignments.
chr_id = 10
wrong = []
for cell in real_matrix:
    ref_chr = cell[chr_id]
    estimator = []
    for ref_point in ref_chr:
        # mean Euclidean distance from this TAD to the TADs of every chromosome
        chr_dists = np.sqrt(np.sum((cell - ref_point)**2,axis=-1)).mean(axis=-1)
        estimator.append(np.argmin(chr_dists))
    wrong.append(np.sum(np.array(estimator) != chr_id))

In [143]:
# Average number of misassigned TADs per realization (from the `wrong` counts).
np.mean(wrong)

8.3300000000000001

In [114]:
# NOTE(review): 0.66 is an undocumented constant - its meaning is not visible in this notebook; confirm.
ntads/0.66

151.5151515151515

In [110]:
# Scratch check: one bucket of TAD ids per chromosome, viewed as an array to read off its shape.
tad_buckets = [range(ntads) for i in range(nchr)]
np.array(tad_buckets).shape

(46L, 100L)

In [94]:
# Scratch: write the first remaining TAD id of each selected chromosome into hybe
# (no +1 offset and no pop here, unlike the encoder loop).
for ind in ind_select:
    hybe[ind]=tad_buckets[ind][0]

In [95]:
# Inspect the current hybe vector.
hybe

array([ 0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  1.,  1.,  1.,  1.,  1.,  1.])

In [67]:
# Scratch: rebuild the per-chromosome TAD buckets and list how many ids each one holds.
nreal,nchr,ntads,ndim = real_matrix.shape
tad_buckets = [range(ntads)for i in range(nchr)]
lens = map(len,tad_buckets)
lens

[100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100]

In [55]:
# Two distance distributions for chromosome 0 of realization 0:
# all pairwise TAD distances (pdist) and distances between consecutive TADs.
# NOTE(review): `normed` is the legacy spelling of hist's `density` option - confirm the installed Matplotlib still accepts it.
from scipy.spatial.distance import pdist
plt.hist(pdist(real_matrix[0][0]),normed=True)
plt.hist(np.sqrt(np.sum(np.diff(real_matrix[0][0],axis=0)**2,axis=-1)),normed=True)
plt.show()

In [54]:
# Euclidean distance between consecutive TADs of chromosome 0 in realization 0.
np.sqrt(np.sum(np.diff(real_matrix[0][0],axis=0)**2,axis=-1))

array([  496.23233286,  1631.3781668 ,  1258.90280477,   529.43926965,
         478.6365907 ,  1224.65537685,  1113.29715536,  1249.77961976,
         369.20484667,   542.05902796,   863.32262967,   858.66085439,
        1198.17897431,   715.25088767,   945.35524174,   315.81670009,
         752.38953997,   679.51192949,  1181.50050014,  1692.37798226,
        1044.30639416,  2590.14482932,  1670.49205492,   766.00304014,
         714.80749687,  1729.93944056,   949.0988148 ,   505.02013097,
         715.94803174,  1781.85020538,  1232.47749631,   597.67278328,
        1647.11208878,   765.92950609,  1268.03832704,  1396.19222885,
        1069.17597861,  1514.1645303 ,  1401.42728274,  1477.05059185,
         819.33151993,   523.76969171,   899.90149512,   837.78076201,
        1355.0347288 ,   947.27669953,   982.64566123,  1985.73240087,
        1702.32779502,  1837.26929096,  1136.65578241,  1169.85686579,
        1108.40533374,  1281.81020626,  1088.8933316 ,   930.00565464,
      

In [44]:
import glob
import numpy as np
# Input CSV files, listed explicitly; the commented-out line is the glob-based alternative.
#files = glob.glob('*.csv')
files= ['chr21.csv', 'chr22.csv']
# NOTE(review): file_ is not read again at cell level (the function below has its own file_ parameter).
file_ =files[0]
def file_to_mat(file_):
    """
    Read a comma-separated numeric file into a 2d float array.
    The first line is skipped as a header, empty fields become nan and
    blank lines are ignored.
    """
    def refine_line(ln):
        return [np.nan if field.strip()=='' else float(field)
                for field in ln.split(',')]
    # the with-block closes the file; rstrip removes only the line ending, so
    # a last line without a trailing newline keeps its final character
    with open(file_,'r') as fid:
        lines = [ln.rstrip('\r\n') for ln in fid]
    return np.array([refine_line(ln) for ln in lines[1:] if ln.strip()!=''])
def data_to_dists(data):
    """
    Distances between consecutive rows that belong to the same cell.
    Each row of data is (icell, iTAD, x, y, z); nan distances are dropped.
    Returns a 1d array of distances.
    """
    dists = []
    prev_cell, prev_xyz = np.nan, None
    for icell,iTAD,x,y,z in data:
        xyz = np.array([x,y,z])
        # nan never equals anything, so the first row only seeds the state
        if prev_cell==icell:
            dists.append(np.sqrt(np.sum((xyz-prev_xyz)**2)))
        prev_cell, prev_xyz = icell, xyz
    dists = np.array(dists)
    return dists[~np.isnan(dists)]

# Consecutive-TAD distances for each of the two files, then pooled.
dists0, dists1 = (data_to_dists(file_to_mat(fname)) for fname in files[:2])
dists = np.concatenate((dists0, dists1))


In [43]:
# Histogram of the pooled consecutive-TAD distances.
# `plt` is the pyplot module imported at the top of the notebook; it is no
# longer re-bound to matplotlib.pylab in the middle of the notebook.
plt.hist(dists,bins=40)
plt.show()

In [49]:
# Standard deviation of the pooled consecutive-TAD distances.
np.std(dists)

0.45172828742659688

In [48]:
# Mean of the pooled consecutive-TAD distances.
np.mean(dists)

0.75397658663865952