In [1]:
import re
import dgl
import torch
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F
import networkx as nx

from pathlib import Path
from androguard.misc import AnalyzeAPK
import pickle
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
import sklearn.metrics as M

from dgl.nn.pytorch import GraphConv, GATConv, SAGEConv
from sklearn.model_selection import StratifiedShuffleSplit

import joblib as J

Using backend: pytorch


In [2]:
#%xmode verbose

## Params

In [3]:
model_kwargs = {'in_dim': 15, 'hidden_dim': 30, 'n_classes': 5 }

In [4]:
train = False

In [5]:
extract = False

## Dataset

In [6]:
def get_samples(base_path):
    base_path = Path(base_path)
    labels_dict = {x:i for i,x in enumerate(sorted(["Adware", "Benigh", "Banking", "SMS", "Riskware"]))}
    if not base_path.exists():
        raise Exception(f'{base_path} does not exist')
    apk_list = sorted([x for x in base_path.iterdir() if not x.is_dir()])
    samples = []
    labels  = {}
    for apk in apk_list:
        samples.append(apk.name)
        labels[apk.name] = labels_dict[re.findall(r'[A-Z](?:[a-z]|[A-Z])+',apk.name)[0]]
    return samples, labels

In [7]:
samples, labels = get_samples('../data/large/raw')

In [8]:
samples[0]

'Adware0000.apk'

In [9]:
raw_prefix = Path('../data/large/raw')
processed_prefix = Path('../data/large/G-feat')

In [10]:
def process(file):
    _, _, dx = AnalyzeAPK(raw_prefix/file)
    cg = dx.get_call_graph()
    opcodes = {}
    for node in cg.nodes():
        sequence = [0] * 15
        if not node.is_external():
            for instr in node.get_method().get_instructions():
                value = instr.get_op_value()
                if value == 0x00: # nop
                    sequence[0] = 1
                elif value >= 0x01 and value <= 0x0D: # mov
                    sequence[1] = 1
                elif value >= 0x0E and value <= 0x11: # return
                    sequence[2] = 1
                elif value == 0x1D or value == 0x1E: # monitor
                    sequence[3] = 1
                elif value >= 0x32 and value <= 0x3D: # if
                    sequence[4] = 1
                elif value == 0x27: # throw
                    sequence[5] = 1
                elif value == 0x28 or value == 0x29: #goto
                    sequence[6] = 1
                elif value >= 0x2F and value <= 0x31: # compare
                    sequence[7] = 1
                elif value >= 0x7F and value <= 0x8F: # unop
                    sequence[8] = 1
                elif value >=90 and value <= 0xE2: # binop
                    sequence[9] = 1
                elif value == 0x21 or (value >= 0x23 and value <= 0x26) or (value >= 0x44 and value <= 0x51): # aop
                    sequence[10] = 1
                elif (value >= 0x52 and value <= 0x5F) or (value >= 0xF2 and value <= 0xF7): # instanceop
                    sequence[11] = 1
                elif (value >= 0x60 and value <= 0x6D): # staticop
                    sequence[12] = 1
                elif (value >= 0x6E and value <= 0x72) and (value >= 0x74 and value <= 0x78) and (value >= 0xF9 and value <= 0xFB):
                    sequence[13] = 1
                elif (value >= 0x22 and value <= 0x25):
                    sequence[14] = 1
        opcodes[node] = {'sequence': sequence}
    nx.set_node_attributes(cg, opcodes)
    labels = {x: {'name': x.full_name} for x in cg.nodes()}
    nx.set_node_attributes(cg, labels)
    cg = nx.convert_node_labels_to_integers(cg)
    torch.save(cg, processed_prefix/ (file.split('.')[0]+'.graph'))

In [11]:
if extract:
    J.Parallel(n_jobs=40)(J.delayed(process)(x) for x in samples);

In [12]:
class MalwareDataset(torch.utils.data.Dataset):
    def __init__(self, save_dir, list_IDs, labels):
        self.save_dir = Path(save_dir)
        self.list_IDs = list_IDs
        self.labels = labels
        self.cache = {}

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        if index not in self.cache:
            ID = self.list_IDs[index]
            graph_path = self.save_dir / (ID.split('.')[0] + '.graph')
            cg = torch.load(graph_path)
            dg = dgl.from_networkx(cg, node_attrs=['sequence'], edge_attrs=['offset'])
            dg = dgl.add_self_loop(dg)
            self.cache[index] = (dg, self.labels[ID])
        return self.cache[index]

## Data Loading

In [13]:
def split_dataset(samples, labels, ratios):
    if sum(ratios) != 1:
        raise Exception("Invalid ratios provided")
    train_ratio, val_ratio, test_ratio = ratios
    sss = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=0)
    train_idx, test_idx = list(sss.split(samples, [labels[x] for x in samples]))[0]
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio/(1-test_ratio), random_state=0)
    test_list = [samples[x] for x in test_idx]
    train_list = [samples[x] for x in train_idx]
    train_idx, val_idx = list(sss.split(train_list, [labels[x] for x in train_list]))[0]
    train_list = [samples[x] for x in train_idx]
    val_list = [samples[x] for x in val_idx]
    return train_list, val_list, test_list

In [14]:
train_list, val_list, test_list = split_dataset(samples, labels, [0.6, 0.2, 0.2])

In [15]:
torch.tensor([len(train_list), len(val_list), len(test_list)]).float()/len(samples)

tensor([0.6000, 0.2000, 0.2000])

In [16]:
def collate(samples):
    graphs, labels = [], []
    for graph, label in samples:
        graphs.append(graph)
        labels.append(label)
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)

In [17]:
train_dataset = MalwareDataset(processed_prefix , train_list, labels)
val_dataset   = MalwareDataset(processed_prefix , val_list,  labels)
test_dataset  = MalwareDataset(processed_prefix , test_list, labels)

In [18]:
train_data = DataLoader(train_dataset, batch_size=8, shuffle=True,  collate_fn=collate, num_workers=8)
val_data   = DataLoader(val_dataset,   batch_size=8, shuffle=False, collate_fn=collate , num_workers=40)
test_data  = DataLoader(test_dataset,  batch_size=8, shuffle=False, collate_fn=collate, num_workers=4)

In [19]:
len(test_dataset.cache)

0

## Model

In [20]:
class MalwareClassifier(pl.LightningModule):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim, aggregator_type='mean')
        self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggregator_type='mean')
        self.classify = nn.Linear(hidden_dim, n_classes)
        self.loss_func = nn.CrossEntropyLoss()
        
        
    def forward(self, g):
        h = g.ndata['sequence'].float()
        #h = torch.cat([g.ndata[x].view(-1,1).float() for x in ['public', 'entrypoint', 'external', 'native', 'codesize' ]], dim=1)
        # h = g.in_degrees().view(-1,1).float()
        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Calculate graph representation by averaging all the node representations.
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg) 
    
    def training_step(self, batch, batch_idx):
        bg, label = batch
        #print("Outer", len(label))
        prediction = self.forward(bg)
        loss = self.loss_func(prediction, label)
        return loss
    
    def validation_step(self, batch, batch_idx):
        bg, label = batch
        prediction = self.forward(bg)
        loss = self.loss_func(prediction, label)
        self.log('val_loss', loss)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [21]:
callbacks = [
    EarlyStopping(monitor='val_loss', patience=5, min_delta=0.01),
]

In [22]:
checkpointer = ModelCheckpoint(filepath='../models/3Nov-{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', mode='min')

In [23]:
classifier= MalwareClassifier(**model_kwargs)
trainer = pl.Trainer(callbacks=callbacks, checkpoint_callback=checkpointer, gpus=[2])

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]


In [24]:
train

False

In [25]:
if train:
    trainer.fit(classifier, train_data, val_data)

## Testing 

In [26]:
classifier_saved = MalwareClassifier.load_from_checkpoint('../models/3Nov-epoch=36-val_loss=0.51.pt.ckpt', **model_kwargs)

In [33]:
classifier_saved(train_dataset[0][0])

tensor([[ -2.0630,   1.1638, -11.9895,   5.1457,  -1.6285]])

In [31]:
classifier_saved.freeze()

In [43]:
predicted = torch.argmax(classifier_saved(dgl.batch([g for g,l in test_dataset])),dim=1)
predicted

tensor([4, 2, 3,  ..., 1, 3, 2])

In [36]:
len(test_dataset)

3302

In [37]:
len(test_dataset.cache)

3302

In [44]:
actual = torch.tensor([l for g,l in test_dataset])
actual

tensor([4, 2, 3,  ..., 3, 3, 2])

In [45]:
print(M.classification_report(actual, predicted, digits=4))

              precision    recall  f1-score   support

           0     0.8911    0.7318    0.8036       302
           1     0.6159    0.8107    0.7000       449
           2     0.9124    0.9282    0.9202       808
           3     0.8524    0.7856    0.8176       779
           4     0.9707    0.9295    0.9497       964

    accuracy                         0.8610      3302
   macro avg     0.8485    0.8372    0.8382      3302
weighted avg     0.8730    0.8610    0.8640      3302



In [46]:
M.confusion_matrix(actual, predicted)

array([[221,  38,   5,  37,   1],
       [  5, 364,  19,  39,  22],
       [  5,  29, 750,  23,   1],
       [ 10, 106,  48, 612,   3],
       [  7,  54,   0,   7, 896]])

In [None]:
"Adware", "Benigh", "Banking", "SMS", "Riskware"

## Results
Accuracy - 86.10%,
Precision - 0.8485,
Recall - 0.8372,
F1 - 0.8382