In [1]:
import pytraj as pt
import pytraj.utils.progress
import numpy as np
import scipy as sp
import matplotlib
from matplotlib import pyplot as plt
import ggplot
import collections
import sys
import gc
import os
import sklearn as skl
from sklearn import decomposition
from sklearn import metrics
from sklearn import discriminant_analysis
from sklearn import cluster
import tqdm
import nglview as nv
import ipywidgets
import copy

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

This notebook uses the lipid center-of-mass coordinate data generated by the
'extract_membrane_headgroups.ipynb' notebook to predict which leaflet each lipid residue
belongs to. The resulting data will be used to construct height and density maps for
individual leaflets.

Since the data files are too large to store as single files on GitHub, we reuse the chunked
save/load functions from 'extract_membrane_headgroups.ipynb' (redefined below) to load and save data files piecewise.

In [2]:
def _chunkFilePath(pathBase,iChunk,nChunks):
    """
        Build the file path of chunk number iChunk out of nChunks chunks:
        pathBase.chunk_#.npy, where # is iChunk zero padded to the width of
        the largest index (nChunks-1) so that the files sort lexicographically.
    """
    # integer width from the largest index avoids float log10 rounding and the
    # exponent notation '%g' switches to for indices >= 1e6
    ndigits=len(str(nChunks-1))
    return '.'.join([pathBase,'chunk_%0*d'%(ndigits,iChunk),'npy'])

def _checkChunkCount(nChunks):
    """Raise ValueError unless nChunks is at least 1."""
    if nChunks<1:
        raise ValueError('nChunks must be at least 1, got %r'%(nChunks,))

def _resetProgress(pbar,total):
    """Point an optional tqdm-style bar at a fresh run of `total` steps."""
    if pbar is not None:
        pbar.total=total
        pbar.n=0
        pbar.refresh()

def saveArrayChunks(pathBase,arr,nChunks,axis=0,
                    pbar=None):
    """
        Split an array into nChunks pieces and save each piece as a .npy file.

        pathBase: the prefix of the file path to save each chunk to.
                    files will be named pathBase.chunk_#.npy, where # is
                    a zero padded integer (to make loading, sorting, etc easier)
        arr: the array to be saved
        nChunks: number of pieces to split arr into (np.array_split is used,
                    so pieces may differ in length by one along axis)
        axis: the axis along which to split the array (default 0)
        pbar: optional tqdm-style progress bar; its total is set to nChunks,
                    its count reset to 0, and it is advanced once per chunk
        raises ValueError: if nChunks < 1
    """
    _checkChunkCount(nChunks)
    arrayChunks=np.array_split(arr,nChunks,axis=axis)
    _resetProgress(pbar,len(arrayChunks))
    for iChunk,arrayChunk in enumerate(arrayChunks):
        np.save(_chunkFilePath(pathBase,iChunk,nChunks),arrayChunk)
        if pbar is not None:
            pbar.update()

def loadArrayChunks(pathBase,nChunks,axis=0,
                    pbar=None):
    """
        Load the chunk files written by saveArrayChunks and join them back
        into a single array.

        pathBase: the same path prefix that was passed to saveArrayChunks
        nChunks: the number of chunks the array was saved as; it sets the
                    zero padding of the file names, so it must match
        axis: the axis along which to concatenate the chunks; should be the
                    axis the array was split along when it was saved
        pbar: optional tqdm-style progress bar; its total is set to nChunks,
                    its count reset to 0, and it is advanced once per chunk
        returns: the reassembled numpy array
        raises ValueError: if nChunks < 1
    """
    _checkChunkCount(nChunks)
    _resetProgress(pbar,nChunks)
    arrayChunks=[]
    for iChunk in range(nChunks):
        arrayChunks.append(np.load(_chunkFilePath(pathBase,iChunk,nChunks)))
        if pbar is not None:
            pbar.update()
    return np.concatenate(arrayChunks,axis=axis)

Now we can load the center of mass data set for each system.

In [8]:
dataFileDir='dataFiles'
comDataDir='/'.join([dataFileDir,'headgroupCoords'])
leafletClusteringDir='/'.join([dataFileDir,'leafletClustering'])

comFileTypeName='headgroup_COM_coords'

systems=['POPC','POPS','PIP2']

nChunks=4

comDataDict={}
print 'Loading data sets ',
with tqdm.tqdm_notebook() as pbar:
    for system in systems:
        print system,
        pbar.set_description_str(system)
        comFileNameBase='.'.join([system,comFileTypeName])
        comFilePathBase='/'.join([comDataDir,comFileNameBase])
        comDataDict[system]=loadArrayChunks(comFilePathBase,nChunks=nChunks,axis=1,
                                            pbar=pbar)
        gc.collect()
    print ''
print 'done loading data'
print '--- --- --- ---'

for setKey in comDataDict:
    print setKey,
    print comDataDict[setKey].shape

Loading data sets 

HBox(children=(IntProgress(value=1, bar_style=u'info', max=1), HTML(value=u'')))

 POPC POPS PIP2 

done loading data
--- --- ---
POPC (1176, 2001, 3)
POPS (1282, 1592, 3)
PIP2 (1290, 1592, 3)


We now have the center-of-mass data for each system. Before diving straight into clustering,
let's have a look at them visually.

In [21]:
# Controls for the interactive plot below. The dropdown lists the systems in
# the order of `systems` (the list comDataDict was filled from) rather than
# comDataDict.keys(), whose order is arbitrary for a Python 2 dict.
systemWidget=widgets.Dropdown(
                 options=list(systems),
                 description='System: ')
frameWidget=widgets.IntSlider(description='Frame:',
                              continuous_update=False)

def updateFrameRange(*args):
    """
        Fit the frame slider's range to axis 1 (the frame axis indexed by
        plotGrids) of the currently selected system, and clamp the slider's
        value into that range.
        args: ignored; lets this be used directly as a traitlets observer,
                which passes a change dict
    """
    lastFrame=comDataDict[systemWidget.value].shape[1]-1
    frameWidget.min=0
    frameWidget.max=lastFrame
    # np.clip returns a numpy scalar; convert it to a plain int, since the
    # slider's integer trait may reject numpy integers (e.g. under Python 3)
    frameWidget.value=int(np.clip(frameWidget.value,0,lastFrame))

# The valid frame range depends on the selected system, so recompute it
# whenever the dropdown's value changes, and once now so the slider starts
# with the correct range instead of its default 0-100.
systemWidget.observe(updateFrameRange,names='value')
updateFrameRange()

def plotGrids(systemName,frameNumber):
    """
        Scatter-plot one frame of a system's center-of-mass coordinates,
        projected onto the XY, XZ and YZ planes in three side-by-side panels.

        systemName: key into comDataDict
        frameNumber: index along axis 1 of comDataDict[systemName]; values
                        outside the valid range are clipped to it
    """
    systemData=comDataDict[systemName]
    lastFrame=systemData.shape[1]-1
    frame=np.clip(frameNumber,0,lastFrame)
    frameCoords=systemData[:,frame,:]

    fig,axs=plt.subplots(1,3)
    fig.set_figheight(4)
    fig.set_figwidth(12)

    # one row per panel: (horizontal column, vertical column, their labels)
    panels=(
        (0,1,'X','Y'),
        (0,2,'X','Z'),
        (1,2,'Y','Z'),
    )
    for ax,(hCol,vCol,hName,vName) in zip(axs.flat,panels):
        ax.scatter(frameCoords[:,hCol],frameCoords[:,vCol])
        ax.set_title('%s: %s%s, frame %g'%(systemName,hName,vName,frame))
        ax.set_xlabel(hName)
        ax.set_ylabel(vName)

    plt.tight_layout()
    plt.show()
    
# Display the interactive plot. interact() returns plotGrids itself, so the
# trailing semicolon suppresses the "<function __main__.plotGrids>" repr.
interact(plotGrids,
         systemName=systemWidget,
         frameNumber=frameWidget);

aW50ZXJhY3RpdmUoY2hpbGRyZW49KERyb3Bkb3duKGRlc2NyaXB0aW9uPXUnU3lzdGVtOiAnLCBvcHRpb25zPSgnUE9QQycsICdQT1BTJywgJ1BJUDInKSwgdmFsdWU9J1BPUEMnKSwgSW50U2zigKY=


<function __main__.plotGrids>