In [2]:
import pickle

class CollectForestInfo:
    intermediateDict = None
    residualDict = None
    descendant_dict = None
    repCommMotifSeq_dict = None # save tree's common motif seq. list
    treeList = None
    
    def __init__(self, intermidiatePicklePath, residualPicklePath, includePairwiseTree, forceMerge=False):
        
        # read the results from pickle files
        with open(intermidiatePicklePath, 'rb') as handle:
            self.intermediateDict = pickle.load(handle)
        with open(residualPicklePath, 'rb') as handle:
            self.residualDict = pickle.load(handle)
        
        self._setForestOutputs(forceMerge)
        self._setTreeList(includePairwiseTree)
        
        
    # get descendant and motif information from pickle
    def _setForestOutputs(self, forceMerge):

        descendant_dict = dict()
        repCommMotifSeq_dict = dict()
        intermediate_list = sorted(self.intermediateDict.items(), key=lambda x : x[0])

        for item in intermediate_list:
            value = item[1] # get original dict value
            score = value[0]
            clusterName = value[1][0]
            memberSet = value[2]
            commonMotifSeq = value[1][1] # list of common motif seq.

            descendants = set()
            for member in memberSet:
                if forceMerge:
                    descendants.add(member)
                else:
                    if member[0] == "G":
                        for descendant in descendant_dict[member]:
                            descendants.add(descendant)
                    else:
                        descendants.add(member)
            descendant_dict[clusterName] = descendants
            repCommMotifSeq_dict[clusterName] = commonMotifSeq

        self.descendant_dict = descendant_dict
        self.repCommMotifSeq_dict = repCommMotifSeq_dict
    
    
    # get those residual trees which isn't sigular
    # collect their clusterName into notLonerList.
    def _setTreeList(self, includePairwiseTree):

        notLonerList = []

        for key, value in self.residualDict.items():
            clusterName = value[0][0]
            motifsList = value[0][1]
            members = value[1]

            notLoner = False

            if(len(members) > 1):
                if(includePairwiseTree):
                    notLoner = True

                else:   # remove 2-member pairs
                    if( len(members) == 2):
                        for member in members:
                            if member[0] == 'G':
                                notLoner = True
                                break
                    else:
                        notLoner = True

            if(notLoner):
                notLonerList.append((clusterName, members))

        notLonerList = sorted(notLonerList, key=lambda x: int(x[0][1::]), reverse=False)

        self.treeList = notLonerList

#     def getGroupMotif_dict(self): # get motif sequence of each group (not only tree root)
#         return self.groupMotif_dict
        
    def getDescendant_dict(self): # get all descendant list(including root and middle nodes)
        return self.descendant_dict
    
    def getTreeList(self): # get tree root list.
        return self.treeList
    
    def getTreeRootNameList(self):
        nameList = list()
        for treeRoot in self.treeList:
            rootName = treeRoot[0] # treeRoot = (ParentNodeName, {children_Node_Names})
            nameList.append(rootName)
        return nameList
    
    def getTreeRootCount(self): # get how many trees in forest
        return len(self.treeList)
    
    def getForestMembers(self):
        forestMemberSet = set()
        trMember_dict = self.getTreeMembers_dict()
        for rootName in trMember_dict:
            members = trMember_dict[rootName]
            forestMemberSet.update(members)
        return forestMemberSet
    
    def getForestMemberCount(self): # return how many malwares in forest
        return len(self.getForestMembers())
    
    def getTreeMembers_dict(self): # key: treeRootName; val: treeMemberSet
        treeMember_dict = dict()
        rootNames = self.getTreeRootNameList()
        for rootName in rootNames:
            members = self.descendant_dict[rootName] # get Node's all descendants
            treeMember_dict[rootName] = members
        return treeMember_dict
    
    def getTreeMembers(self, rootName): # return members (set) in specific treeRoot
        trMember_dict = self.getTreeMembers_dict()
        return trMember_dict[rootName] # type == set()
    
    def getTreeSamples(self, rootName): # return how many samples in a tree (for PE files)
        members = self.getTreeMembers(rootName)
        samples = set()
        for mem in members:
            samples.add(mem.split('_')[0])
        return samples
    
    def getRepAPISeq_dict(self): # key: treeRootName; val: RepAPISeq <list>
        repAPISeq_dict = dict()
        rootNames = self.getTreeRootNameList()
        for rootName in rootNames:
            repAPISeq = self.getRepAPISeq(rootName) # get Rep API Seq of each root
            repAPISeq_dict[rootName] = repAPISeq # add into dict
        return repAPISeq_dict
            
    def getRepAPISeq(self, rootName): # get Rep API Seq of root 數字
        repMotifList = list()
        commMotifSeq = self.repCommMotifSeq_dict[rootName] # get CMS list
        
        commonAPISeq = [] # merge all motif's APIs
        for motifAPI in commMotifSeq:
            commonAPISeq.extend(motifAPI)
        return commonAPISeq
    
    def getRepMotifCount(self, rootName): # get motif count of root 
        commMotifSeq = self.repCommMotifSeq_dict[rootName]
        return len(commMotifSeq)
    
    def getRepMotifSequence(self, rootName): # get motif sequence of root 文字
        return self.repCommMotifSeq_dict[rootName]

In [15]:
# ### unit test

pkl_dir_path = 'output/RasMMA_forest/40.picsys_0.8/pickle/'
interPkl = pkl_dir_path + '40.picsys_0.8_intermediate.pickle'
resPkl = pkl_dir_path + '40.picsys_0.8_residual.pickle'
TreeUtil = CollectForestInfo
testFamilyForest = TreeUtil(interPkl, resPkl, True)



for root in rootNames:
    rootAPISeq = testFamilyForest.getRepAPISeq(root)
    motifCount = testFamilyForest.getRepMotifCount(root)
    print(len(rootAPISeq), motifCount)
    
    motifSeq = testFamilyForest.getRepAPISeq(root)
    motifLenList = [len(motif) for motif in motifSeq]
    print(motifLenList)

201 5
[108, 86, 36, 37, 37, 46, 102, 74, 100, 86, 107, 36, 36, 86, 74, 93, 134, 110, 131, 34, 35, 73, 78, 88, 79, 79, 36, 66, 79, 92, 110, 82, 94, 104, 111, 122, 102, 112, 93, 98, 112, 114, 92, 97, 92, 104, 92, 206, 92, 113, 92, 97, 92, 105, 92, 117, 92, 100, 92, 107, 92, 93, 92, 94, 79, 36, 66, 79, 79, 36, 66, 79, 79, 36, 96, 79, 96, 79, 36, 96, 79, 96, 79, 36, 96, 79, 96, 79, 36, 79, 96, 79, 79, 79, 36, 96, 79, 96, 79, 36, 66, 79, 36, 66, 79, 36, 66, 79, 79, 36, 96, 79, 96, 79, 36, 96, 79, 96, 79, 36, 96, 79, 96, 79, 36, 96, 79, 96, 79, 36, 96, 79, 96, 79, 36, 79, 36, 96, 79, 96, 79, 36, 79, 36, 96, 79, 96, 79, 96, 79, 96, 79, 36, 96, 79, 96, 79, 36, 96, 79, 96, 79, 36, 79, 36, 79, 36, 79, 36, 66, 79, 79, 36, 79, 36, 79, 36, 79, 36, 79, 79, 36, 79, 79, 36, 79, 51, 44, 51, 44, 110, 79, 36, 96, 79, 96, 165, 280, 89, 92, 85]
76 4
[35, 37, 55, 108, 86, 36, 86, 46, 102, 74, 100, 86, 74, 37, 35, 36, 33, 102, 131, 109, 116, 112, 101, 1885, 101, 1885, 101, 1885, 101, 1885, 101, 1885, 101, 18

In [11]:
rootNames

['G80', 'G111', 'G114']

In [16]:
rootNames = testFamilyForest.getTreeRootNameList()
for roots in rootNames:
    kk = testFamilyForest.getTreeMembers(roots) # 各tree底下有哪些hooklogs
    mm = testFamilyForest.getRepMotifSequence(roots)  # 各tree的REP是誰
print(kk)
print(mm)

{'4b0f08_3208', '00c03e_3180', '60df90_3416', '5bf528_3316', '701a96_3316', '73f36e_3276', 'ad9850_3204', 'e7b5bc_3200', '0720d6_3344', '0c9440_3316', 'bb4c17_3372', '4f002d_3112', '452eea_2900', 'e37a94_3412', '30aeef_3352', '91cb7c_3328', '31917e_3248', 'c68fd2_3208', '53c295_3320', '00c03e_3256', 'e0966a_2876', 'c68fd2_3168', '1d412b_3052', 'bed569_3272', 'd76b44_3316', '31917e_3252', '58eca9_3312', 'c2896b_2936', 'f1c2e7_3224', 'bd5efb_3264', '1e9b7a_3368', '0a7f85_2964'}
[['RegQueryValue#PR@HKLM@sys_curCtlSet_ctl_sessionManager\\*#PR@SUBK@criticalsectiontimeout#PR@0#PR@12f9b0#Ret#0', 'RegQueryValue#PR@HKLM@soft_ms_ole\\*#PR@SUBK@rwlockresourcetimeout#PR@0#PR@12f9b4#Ret#P'], ['LoadLibrary#PR@USR@malware@ENU#Ret#N', 'LoadLibrary#PR@USR@malware@EN#Ret#N', 'RegQueryValue#PR@HKLM@soft_ms_win_currentversion\\setup#PR@SUBK@version#PR@0#PR@44ccec#Ret#P', 'RegSetValue#PR@HKLM@soft_ms_win_currentversion\\setup\\version#PR@REG_DWORD#PR@131#Ret#0', 'CreateFile#PR@SYS@EXE#PR@GENERIC_READ;GENER