In [1]:
import math
import os
import pickle
import re
from pathlib import Path

import dgl
import joblib as J
import networkx as nx
import pytorch_lightning as pl
import sklearn.metrics as M
import torch
import torch.nn as nn
import torch.nn.functional as F
from androguard.misc import AnalyzeAPK
from dgl.nn.pytorch import GraphConv, GATConv, SAGEConv
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import NeptuneLogger
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import DataLoader

Using backend: pytorch


In [2]:
def get_api_list(file):
    """Load the API vocabulary and map each entry to its line index.

    Args:
        file: Path of a text file holding one entry per line.

    Returns:
        dict mapping every whitespace-stripped line to its 0-based line
        number. If the same entry occurs more than once, the index of its
        last occurrence is kept.
    """
    # The context manager closes the handle deterministically; the previous
    # `open(file).readlines()` left the file open until garbage collection.
    with open(file) as handle:
        return {line.strip(): index for index, line in enumerate(handle)}

In [3]:
# API vocabulary (entry -> index) read from 'api.list'; its length is used as the model's input dimension.
api_list = get_api_list('api.list')

In [4]:
len(api_list)

226

## Params

In [5]:
# Constructor arguments for MalwareClassifier: one input feature per API vocabulary entry,
# 64 hidden units, and a single output unit (the collate function produces 0/1 targets).
model_kwargs = {'in_dim': len(api_list), 'hidden_dim': 64, 'n_classes': 1 }

In [6]:
# Switch: when True, the training cell calls trainer.fit().
train = True

In [7]:
# Switch: when True, the extraction cell runs process() over every sample and writes the '.graph' files.
extract = False

## Dataset

In [8]:
def get_samples(base_path):
    """List the files directly inside `base_path` and derive a class label for each.

    The label comes from the first run of letters in the file name that starts
    with an upper-case letter (e.g. 'Adware0767.apk' -> 'Adware'), looked up in
    the alphabetically sorted list of the five class names.

    Args:
        base_path: Directory to scan (sub-directories are ignored).

    Returns:
        (samples, labels): file names in sorted order, and a dict mapping each
        file name to its integer class index.

    Raises:
        Exception: if `base_path` does not exist.
    """
    root = Path(base_path)
    class_names = sorted(["Adware", "Benigh", "Banking", "SMS", "Riskware"])
    class_index = dict(zip(class_names, range(len(class_names))))
    if not root.exists():
        raise Exception(f'{root} does not exist')

    family_pattern = re.compile(r'[A-Z](?:[a-z]|[A-Z])+')
    samples, labels = [], {}
    for entry in sorted(root.iterdir()):
        if entry.is_dir():
            continue
        file_name = entry.name
        samples.append(file_name)
        # The first match of the pattern is taken as the class name
        labels[file_name] = class_index[family_pattern.findall(file_name)[0]]
    return samples, labels

In [9]:
# File names and integer class labels for everything found in the raw data folder.
samples, labels = get_samples('../data/large/raw')

In [10]:
# Input folder read by process() and output folder for the saved '.graph' files.
# raw_prefix points at the same location that was passed to get_samples() as a string literal.
raw_prefix = Path('../data/large/raw')
processed_prefix = Path('../data/large/APIFeatures')

In [11]:
def process(file):
    """Build the call graph of one sample and save it with per-node attributes.

    Every node receives two attributes:
      * 'api_package': for external nodes, the `api_list` index of the dotted
        name derived from the node's full name (None when the name is not in
        `api_list`); always None for non-external nodes.
      * 'name': the node's full name.

    Nodes are then relabelled to consecutive integers and the graph is written
    to `processed_prefix/<file name up to the first '.'>.graph` via torch.save.

    Args:
        file: File name of the sample, relative to `raw_prefix`.
    """
    _apk, _dex, analysis = AnalyzeAPK(raw_prefix/file)
    call_graph = analysis.get_call_graph()

    node_attrs = {}
    for method in call_graph.nodes():
        package_index = None
        if method.is_external():
            # Text before the first ';' without its leading character, split on '/',
            # with the last two components dropped and the rest joined by '.'
            class_path = method.full_name.split(';')[0][1:]
            package = '.'.join(map(str, class_path.split('/')[:-2]))
            package_index = api_list.get(package, None)
        node_attrs[method] = {"api_package": package_index, 'name': method.full_name}
    nx.set_node_attributes(call_graph, node_attrs)

    relabeled = nx.convert_node_labels_to_integers(call_graph)
    torch.save(relabeled, processed_prefix / (file.split('.')[0] + '.graph'))

In [12]:
if extract:
    # torch.save() does not create missing folders, so make sure the output
    # directory exists before the workers start writing '.graph' files.
    processed_prefix.mkdir(parents=True, exist_ok=True)
    # Run process() for every sample across 40 parallel joblib workers.
    J.Parallel(n_jobs=40)(J.delayed(process)(x) for x in samples);

In [13]:
class MalwareDataset(torch.utils.data.Dataset):
    """Dataset of saved call graphs with one-hot API features on every node.

    Args:
        save_dir: Directory containing the '.graph' files.
        list_IDs: Sample file names; the part before the first '.' selects the
            '.graph' file to load.
        labels: Mapping from sample file name to its label.
        n_features: Length of the per-node feature vector. Defaults to
            `len(api_list)` when omitted.
    """

    def __init__(self, save_dir, list_IDs, labels, n_features=None):
        self.save_dir = Path(save_dir)
        self.list_IDs = list_IDs
        self.labels = labels
        self.n_features = len(api_list) if n_features is None else n_features
        # index -> (DGL graph, label), filled lazily by __getitem__
        self.cache = {}

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)

    def get_node_vector(self, pos):
        """Return a one-hot vector with a 1 at `pos`, or all zeros when `pos` is None."""
        vector = torch.zeros(self.n_features)
        # Compare against None explicitly: index 0 is a valid position, and a
        # plain truthiness test would silently leave it all-zero.
        if pos is not None:
            vector[pos] = 1
        return vector

    def __getitem__(self, index):
        'Generates one sample of data'
        if index not in self.cache:
            ID = self.list_IDs[index]
            graph_path = self.save_dir / (ID.split('.')[0] + '.graph')
            # NOTE: torch.load unpickles the file; only load graphs produced by a trusted source.
            cg = torch.load(graph_path)
            feature = {
                n: self.get_node_vector(pos)
                for n, pos in nx.get_node_attributes(cg, 'api_package').items()
            }
            nx.set_node_attributes(cg, feature, 'feature')
            dg = dgl.from_networkx(cg, node_attrs=['feature'])
            dg = dgl.add_self_loop(dg)
            self.cache[index] = (dg, self.labels[ID])
        return self.cache[index]

## Data Loading

In [14]:
def split_dataset(samples, labels, ratios):
    """Split `samples` into stratified train / validation / test lists.

    Args:
        samples: Sequence of sample identifiers.
        labels: Mapping from sample identifier to its class label (used for
            stratification).
        ratios: Three fractions (train, val, test) that sum to 1.

    Returns:
        (train_list, val_list, test_list): three disjoint lists of identifiers.

    Raises:
        Exception: if the ratios do not sum to 1.
    """
    # Tolerant comparison: sums such as 0.7 + 0.2 + 0.1 are not exactly 1.0 in floating point
    if not math.isclose(sum(ratios), 1.0):
        raise Exception("Invalid ratios provided")
    train_ratio, val_ratio, test_ratio = ratios

    # Stage 1: carve the test set out of all samples
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=0)
    remainder_idx, test_idx = next(splitter.split(samples, [labels[x] for x in samples]))
    test_list = [samples[i] for i in test_idx]
    remainder = [samples[i] for i in remainder_idx]

    # Stage 2: split what is left into train and validation
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio / (1 - test_ratio), random_state=0)
    train_idx, val_idx = next(splitter.split(remainder, [labels[x] for x in remainder]))
    # These indices are positions within `remainder`, not within `samples`;
    # indexing `samples` here would leak test items into train/val.
    train_list = [remainder[i] for i in train_idx]
    val_list = [remainder[i] for i in val_idx]
    return train_list, val_list, test_list

In [15]:
# 60 % train / 20 % validation / 20 % test, stratified by class label.
train_list, val_list, test_list = split_dataset(samples, labels, [0.6, 0.2, 0.2])

In [16]:
# Sanity check: fraction of all samples that ended up in each of the three lists.
torch.tensor([len(train_list), len(val_list), len(test_list)]).float()/len(samples)

tensor([0.6000, 0.2000, 0.2000])

In [17]:
def collate(samples):
    """Merge (graph, label) pairs into one batched graph and binary float targets.

    Args:
        samples: Iterable of (DGL graph, integer class label) pairs.

    Returns:
        (batched_graph, targets): the graphs combined with dgl.batch, and a
        float tensor that is 0.0 where the label equals 2 and 1.0 elsewhere.
    """
    pairs = list(samples)
    batched_graph = dgl.batch([graph for graph, _ in pairs])
    class_ids = torch.tensor([label for _, label in pairs])
    # Label 2 becomes the negative class (0); every other label becomes 1
    targets = (class_ids != 2).float()
    return batched_graph, targets

In [18]:
# One dataset per split, all reading from the same processed folder and sharing the label mapping.
train_dataset, val_dataset, test_dataset = (
    MalwareDataset(processed_prefix, id_list, labels)
    for id_list in (train_list, val_list, test_list)
)

In [19]:
len(train_dataset.cache), len(val_dataset.cache), len(test_dataset.cache)

(0, 0, 0)

In [20]:
# Loaders yield batches of 8 graphs merged by collate(); only the training loader shuffles.
# NOTE(review): the worker counts differ per loader (8 / 40 / 4) — confirm this is intentional.
train_data = DataLoader(train_dataset, batch_size=8, shuffle=True,  collate_fn=collate, num_workers=8)
val_data   = DataLoader(val_dataset,   batch_size=8, shuffle=False, collate_fn=collate , num_workers=40)
test_data  = DataLoader(test_dataset,  batch_size=8, shuffle=False, collate_fn=collate, num_workers=4)

In [45]:
# Rebuilds the test loader without the num_workers argument, replacing the loader defined in the previous cell.
test_data  = DataLoader(test_dataset,  batch_size=8, shuffle=False, collate_fn=collate) # previously: num_workers=4

## Model

In [21]:
class MalwareClassifier(pl.LightningModule):
    """Two-layer GraphSAGE model with a sum readout and a sigmoid output layer.

    Args:
        in_dim: Size of the per-node 'feature' vector.
        hidden_dim: Width of both SAGEConv layers.
        n_classes: Number of output units of the final linear layer.
    """

    def __init__(self, in_dim, hidden_dim, n_classes):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim, aggregator_type='mean')
        self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggregator_type='mean')
        self.classify = nn.Linear(hidden_dim, n_classes)
        self.loss_func = nn.BCELoss()

    def forward(self, g):
        """Return sigmoid scores for every graph in the batched graph `g`.

        Reads node features from g.ndata['feature'] and stores the final node
        embeddings in g.ndata['h'].
        """
        h = g.ndata['feature']
        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Graph representation = sum of all node representations.
        # NOTE(review): the recorded test predictions are all exactly 1.0, which
        # suggests the sigmoid saturates; consider dgl.mean_nodes — confirm.
        hg = dgl.sum_nodes(g, 'h')
        # Drop only the trailing unit dimension. A bare squeeze() would also
        # remove the batch dimension when the batch holds a single graph, and
        # the resulting 0-dim output no longer matches the (1,) label in BCELoss.
        return torch.sigmoid(self.classify(hg)).squeeze(-1)

    def training_step(self, batch, batch_idx):
        """Compute and log the BCE loss of one training batch."""
        bg, label = batch
        prediction = self.forward(bg)
        loss = self.loss_func(prediction, label)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the BCE loss of one validation batch as 'val_loss'."""
        bg, label = batch
        prediction = self.forward(bg)
        loss = self.loss_func(prediction, label)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [22]:
# Early stopping watches 'val_loss' (logged in validation_step) with patience=5 and min_delta=0.01.
callbacks = [
    EarlyStopping(monitor='val_loss', patience=5, min_delta=0.01),
]

In [23]:
# Checkpoints are selected on 'val_loss' (mode='min'); epoch and val_loss are templated into the file name.
# NOTE(review): the checkpoint loaded later in this notebook ends in '.pt.ckpt', so a '.ckpt' suffix
# appears to be appended to this pattern — confirm.
checkpointer = ModelCheckpoint(filepath='../models/11Nov-{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', mode='min')

In [24]:
# Experiment logger. The API token is a secret and must not be stored in the
# notebook: it is read from the NEPTUNE_API_TOKEN environment variable instead.
# SECURITY: the token that was previously hard-coded here has been exposed and
# should be revoked in the Neptune account settings.
neptune_logger = NeptuneLogger(
    api_key=os.environ.get('NEPTUNE_API_TOKEN'),
    project_name='vinayakakv/gcn-android-malware',
    experiment_name='Binary Classification',  # Optional,
    params=model_kwargs,  # Optional,
    tags=['binary_classification', 'MalDroid2020'],  # Optional,
)

https://ui.neptune.ai/vinayakakv/gcn-android-malware/e/GCNAN-6


NeptuneLogger will work in online mode


In [25]:
# Fresh model plus a Lightning trainer wired to the early-stopping callback, the checkpointer and the Neptune logger.
# gpus=[2] selects the GPU with index 2 (see CUDA_VISIBLE_DEVICES in the recorded output).
classifier= MalwareClassifier(**model_kwargs)
trainer = pl.Trainer(callbacks=callbacks, checkpoint_callback=checkpointer, logger=neptune_logger, gpus=[2])

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]


In [28]:
train

True

In [None]:
# Train only when the `train` switch is set; val_data supplies the 'val_loss' used by the callbacks.
if train:
    trainer.fit(classifier, train_data, val_data)


  | Name      | Type     | Params
---------------------------------------
0 | conv1     | SAGEConv | 29 K  
1 | conv2     | SAGEConv | 8 K   
2 | classify  | Linear   | 65    
3 | loss_func | BCELoss  | 0     


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

In [31]:
trainer.current_epoch

5

trainering 

In [32]:
# Restore a model from a saved checkpoint; the constructor arguments are passed again via model_kwargs.
# NOTE(review): the file name is hard-coded to one specific run (epoch 0, val_loss 29.52) — update it when retraining.
classifier_saved = MalwareClassifier.load_from_checkpoint('../models/11Nov-epoch=00-val_loss=29.52.pt.ckpt', **model_kwargs)

In [41]:
# Move the restored model to the CPU.
classifier_saved = classifier_saved.to('cpu')

In [38]:
# Merges every graph of the test dataset into one single batched graph on the CPU.
# NOTE(review): no model is called here, so torch.no_grad() looks unnecessary; `data` is also
# reused as the loop variable of the prediction loop further down, which overwrites this value.
with torch.no_grad():
    data = dgl.batch([g.to('cpu') for g,l in test_dataset])

In [43]:
predicted = []

In [48]:
# Score the test set batch by batch with the model restored from the checkpoint
# (`classifier_saved`, already moved to the CPU) rather than the in-memory
# training instance.
classifier_saved.eval()
with torch.no_grad():
    # Rebuild the list from scratch so re-running this cell cannot append
    # a second copy of the predictions.
    predicted = [classifier_saved(graphs) for graphs, _ in test_data]

In [49]:
predicted

[tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([1.

In [34]:
# NOTE(review): forward() requires a graph argument, so calling the model with no input cannot
# succeed as written. The recorded output is an out-of-memory error — confirm which input was
# intended here (the fully batched test graph would explain the allocation size).
predicted = classifier()
predicted

RuntimeError: [enforce fail at CPUAllocator.cpp:64] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 43235428104 bytes. Error code 12 (Cannot allocate memory)

In [None]:
predicted

In [61]:
# Ground-truth labels taken directly from the dataset, i.e. the raw class indices
# rather than the 0/1 targets that collate() produces for training.
actual = torch.tensor([l for g,l in test_dataset])
actual

tensor([4, 2, 3,  ..., 3, 3, 2])

In [62]:
# Per-class precision / recall / F1 of `predicted` against `actual`.
# NOTE(review): the recorded report has five classes while model_kwargs sets n_classes=1,
# so this output looks like it came from a different (multi-class) run — confirm.
print(M.classification_report(actual, predicted, digits=4))

              precision    recall  f1-score   support

           0     0.9139    0.8079    0.8576       302
           1     0.7181    0.6526    0.6838       449
           2     0.9272    0.8824    0.9042       808
           3     0.7611    0.7728    0.7669       779
           4     0.8641    0.9564    0.9079       964

    accuracy                         0.8401      3302
   macro avg     0.8369    0.8144    0.8241      3302
weighted avg     0.8399    0.8401    0.8387      3302



In [63]:
M.confusion_matrix(actual, predicted)

array([[244,  15,   2,  37,   4],
       [  7, 293,  21,  70,  58],
       [  5,  25, 713,  60,   5],
       [ 11,  55,  33, 602,  78],
       [  0,  20,   0,  22, 922]])

In [64]:
sorted(["Adware", "Benigh", "Banking", "SMS", "Riskware"])

['Adware', 'Banking', 'Benigh', 'Riskware', 'SMS']

In [75]:
predicted[15]

tensor(2)

In [73]:
torch.where(actual!=predicted)[0]

tensor([   8,   15,   20,   25,   50,   56,   57,   58,   62,   67,   72,   73,
          74,   75,   76,   96,   97,   99,  100,  114,  120,  121,  123,  125,
         126,  131,  138,  140,  143,  158,  159,  186,  187,  211,  212,  221,
         233,  239,  241,  263,  271,  281,  284,  297,  298,  305,  306,  309,
         310,  313,  316,  318,  321,  329,  333,  335,  342,  355,  359,  360,
         368,  370,  388,  390,  392,  396,  407,  411,  412,  426,  429,  434,
         435,  442,  450,  453,  456,  459,  467,  469,  470,  471,  475,  476,
         498,  505,  512,  522,  526,  528,  541,  546,  552,  553,  555,  557,
         560,  567,  573,  577,  578,  579,  589,  602,  616,  618,  619,  625,
         632,  642,  652,  653,  656,  658,  663,  664,  666,  672,  677,  682,
         685,  690,  705,  708,  714,  728,  733,  734,  735,  744,  747,  775,
         780,  781,  785,  798,  800,  802,  804,  811,  820,  825,  829,  830,
         833,  838,  842,  846,  865,  8

In [70]:
import numpy as np

In [71]:
test_list_np = np.array(test_list)

In [72]:
test_list_np[torch.where(actual!=predicted)[0].numpy()]

array(['Adware0767.apk', 'Banking1835.apk', 'Riskware0594.apk',
       'Banking0851.apk', 'Riskware4216.apk', 'Riskware2248.apk',
       'Banking0452.apk', 'Benigh0310.apk', 'Banking1730.apk',
       'Banking0032.apk', 'Benigh0353.apk', 'Banking1341.apk',
       'Banking0427.apk', 'Benigh2838.apk', 'Banking1769.apk',
       'Banking1240.apk', 'Riskware4175.apk', 'Riskware1607.apk',
       'Riskware0891.apk', 'Riskware2847.apk', 'Riskware4066.apk',
       'Banking1974.apk', 'Benigh0736.apk', 'Riskware1756.apk',
       'Adware0635.apk', 'Adware0093.apk', 'Banking0844.apk',
       'Riskware3207.apk', 'SMS2836.apk', 'Benigh0740.apk',
       'Banking1832.apk', 'Riskware2993.apk', 'Banking0169.apk',
       'Riskware0301.apk', 'Banking0434.apk', 'Riskware1234.apk',
       'Riskware3747.apk', 'Banking0577.apk', 'Banking1982.apk',
       'Benigh1585.apk', 'Banking0099.apk', 'Banking0505.apk',
       'Banking2505.apk', 'Banking0772.apk', 'Benigh3860.apk',
       'Riskware2045.apk', 'Riskware3325

## Results
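Five-class test-set metrics, taken from the classification report above (3302 samples):
Accuracy - 84.01%,
Precision (macro avg) - 0.8369,
Recall (macro avg) - 0.8144,
F1 (macro avg) - 0.8241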