In [1]:
import re
import dgl
import torch

from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F
import networkx as nx

from pathlib import Path
from androguard.misc import AnalyzeAPK

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
import sklearn.metrics as M


from dgl.nn.pytorch import GraphConv, GATConv, SAGEConv
from sklearn.model_selection import StratifiedShuffleSplit

Using backend: pytorch


## Params

In [2]:
# Constructor arguments for MalwareClassifier; also passed again when the
# model is re-created from a checkpoint.
model_kwargs = dict(in_dim=5, hidden_dim=64, n_classes=1)

In [3]:
# Toggle for the training cell: trainer.fit(...) only runs when this is True.
train = True

## Dataset

In [4]:
def get_samples(base_path):
    """List the APK files in a folder and derive an integer class label for each.

    The class name is the first capital-initial run of letters in the file
    name (e.g. the alphabetic prefix of ``ClassXXXX.apk``); it is mapped to its
    position in the alphabetically sorted list of known class names.

    Args:
        base_path: Folder containing the APK files (sub-folders are ignored).

    Returns:
        A ``(samples, labels)`` pair: the file names in sorted path order and
        a dict mapping each file name to its integer class index.

    Raises:
        Exception: if ``base_path`` does not exist.
        KeyError / IndexError: if a file name carries no known class name.
    """
    root = Path(base_path)
    known_classes = sorted(["Adware", "Benigh", "Banking", "SMS", "Riskware"])
    class_index = {name: position for position, name in enumerate(known_classes)}
    if not root.exists():
        raise Exception(f'{root} does not exist')

    class_pattern = re.compile(r'[A-Z](?:[a-z]|[A-Z])+')
    files = sorted(entry for entry in root.iterdir() if not entry.is_dir())
    samples = [entry.name for entry in files]
    labels = {name: class_index[class_pattern.findall(name)[0]] for name in samples}
    return samples, labels
    

In [5]:
class MalwareDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(graph, label)`` pairs for APK files.

    An APK is converted to a DGL call graph the first time it is requested,
    saved under ``save_dir`` and afterwards loaded from disk. Loaded items are
    additionally kept in a per-instance in-memory cache keyed by index.

    Args:
        raw_dir: Flat folder containing the raw APK files, each named
            ``ClassXXXX.apk``.
        save_dir: Folder holding the processed ``ClassXXXX.graph`` files. It is
            created on demand when the first graph is written.
        list_IDs: APK file names that belong to this split.
        labels: Mapping from APK file name to its integer class label.
    """

    def __init__(self, raw_dir, save_dir, list_IDs, labels):
        self.raw_dir = Path(raw_dir)
        self.save_dir = Path(save_dir)
        self.list_IDs = list_IDs
        self.labels = labels
        # index -> (graph, label); filled lazily by __getitem__.
        self.cache = {}

    def _graph_path(self, apk_file):
        """Return the location of the processed graph file for ``apk_file``."""
        # Everything before the first '.' is the stem, so 'ClassXXXX.apk'
        # maps to 'ClassXXXX.graph'.
        return self.save_dir / (apk_file.split('.')[0] + '.graph')

    def process(self, apk_file):
        """Build the call graph of one APK and save it as a DGL graph.

        The call graph comes from androguard's analysis object. Each node gets
        a ``name`` attribute, nodes are relabelled to consecutive integers, and
        the graph is converted to DGL keeping the node attributes ``external``,
        ``entrypoint``, ``native``, ``public``, ``codesize`` and the edge
        attribute ``offset``. A self-loop is added to every node before saving.

        Args:
            apk_file: File name of the APK inside ``raw_dir``.
        """
        _, _, dx = AnalyzeAPK(self.raw_dir / apk_file)
        cg = dx.get_call_graph()
        names = {node: {'name': node.full_name} for node in cg.nodes()}
        nx.set_node_attributes(cg, names)
        cg = nx.convert_node_labels_to_integers(cg)
        dg = dgl.from_networkx(
            cg,
            node_attrs=['external', 'entrypoint', 'native', 'public', 'codesize'],
            edge_attrs=['offset'],
        )
        dg = dgl.add_self_loop(dg)
        # A fresh checkout may not have the output folder yet; without this
        # torch.save fails on the very first sample.
        self.save_dir.mkdir(parents=True, exist_ok=True)
        torch.save(dg, self._graph_path(apk_file))

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Return ``(graph, label)`` for the sample at ``index``.

        The graph is processed first if no saved file exists. An out-of-range
        index raises ``IndexError`` from ``list_IDs``.
        """
        if index not in self.cache:
            apk_file = self.list_IDs[index]
            graph_path = self._graph_path(apk_file)
            if not graph_path.exists():
                self.process(apk_file)
            # NOTE: torch.load unpickles the file; only point save_dir at
            # graph files produced by this class.
            self.cache[index] = (torch.load(graph_path), self.labels[apk_file])
        return self.cache[index]

## Data Loading

In [6]:
# Collect the APK file names and their integer class labels from the raw folder.
samples, labels = get_samples('../data/large/raw')

In [7]:
def split_dataset(samples, labels, ratios):
    """Split sample names into stratified train / validation / test lists.

    The test set is carved out of all samples first; the remainder is then
    split into train and validation. Both splits are stratified on the class
    label and use a fixed random state.

    Args:
        samples: Sequence of sample identifiers.
        labels: Mapping from sample identifier to class label.
        ratios: ``(train, val, test)`` fractions; they must sum to 1.

    Returns:
        ``(train_list, val_list, test_list)``: three disjoint lists that
        together contain every element of ``samples``.

    Raises:
        Exception: if the ratios do not sum to 1 (within rounding error).
    """
    # Compare with a tolerance: a float sum can miss 1 by a rounding error
    # (0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999).
    if abs(sum(ratios) - 1) > 1e-9:
        raise Exception("Invalid ratios provided")
    train_ratio, val_ratio, test_ratio = ratios

    # Stage 1: hold out the test set from the full sample list.
    outer = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=0)
    rest_idx, test_idx = next(iter(outer.split(samples, [labels[x] for x in samples])))
    test_list = [samples[i] for i in test_idx]
    rest_list = [samples[i] for i in rest_idx]

    # Stage 2: split what is left; the validation share is rescaled because it
    # is now a fraction of the remainder rather than of the whole dataset.
    inner = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio / (1 - test_ratio), random_state=0)
    train_idx, val_idx = next(iter(inner.split(rest_list, [labels[x] for x in rest_list])))

    # These indices are positions in rest_list, NOT in samples. Indexing
    # samples here would leak test items into train/val.
    train_list = [rest_list[i] for i in train_idx]
    val_list = [rest_list[i] for i in val_idx]
    return train_list, val_list, test_list

In [8]:
# 60% train / 20% validation / 20% test.
train_list, val_list, test_list = split_dataset(samples, labels, [0.6, 0.2, 0.2])

In [9]:
# Sanity check: fraction of all samples that landed in each split.
torch.tensor([len(train_list), len(val_list), len(test_list)]).float()/len(samples)

tensor([0.6000, 0.2000, 0.2000])

In [10]:
def collate(samples):
    """Combine ``(graph, label)`` pairs into one batched graph and binary targets.

    Class index 2 (the third name in the sorted class list of ``get_samples``)
    becomes target 1.0; every other class index becomes 0.0.

    Args:
        samples: Non-empty sequence of ``(graph, class_index)`` pairs.

    Returns:
        ``(batched_graph, targets)`` where ``targets`` is a float tensor with
        one entry per input pair.
    """
    graph_seq, class_ids = zip(*samples)
    batched_graph = dgl.batch(list(graph_seq))
    is_positive = torch.tensor(class_ids) == 2
    return batched_graph, is_positive.float()

In [11]:
# One dataset per split; all three read raw APKs from, and store processed
# graphs in, the same two folders.
train_dataset, val_dataset, test_dataset = (
    MalwareDataset('../data/large/raw', '../data/large/processed', split_ids, labels)
    for split_ids in (train_list, val_list, test_list)
)

# Only the training loader shuffles; validation and test keep dataset order.
train_data, val_data, test_data = (
    DataLoader(dataset, batch_size=32, shuffle=do_shuffle, collate_fn=collate, num_workers=10)
    for dataset, do_shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

## Model

In [12]:
class MalwareClassifier(pl.LightningModule):
    """Graph-level binary classifier: two GraphSAGE layers, mean readout, linear head.

    Args:
        in_dim: Number of input features per node; must match the number of
            node attributes listed in ``_NODE_FEATURES``.
        hidden_dim: Width of both GraphSAGE layers.
        n_classes: Output size of the linear head (1 for the binary setup
            used with ``BCELoss``).
    """

    # Node attributes concatenated, in this order, into the input feature matrix.
    _NODE_FEATURES = ('public', 'entrypoint', 'external', 'native', 'codesize')

    def __init__(self, in_dim, hidden_dim, n_classes):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim, aggregator_type='mean')
        self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggregator_type='mean')
        self.classify = nn.Linear(hidden_dim, n_classes)
        self.loss_func = nn.BCELoss()

    def forward(self, g):
        """Return one sigmoid probability per graph in the (batched) graph ``g``.

        Side effect: stores the final node embeddings in ``g.ndata['h']``.
        """
        # Each node attribute becomes one column of the input feature matrix.
        columns = [g.ndata[name].view(-1, 1).float() for name in self._NODE_FEATURES]
        h = torch.cat(columns, dim=1)
        # Two rounds of graph convolution, each followed by ReLU.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Graph representation = mean of its node representations.
        hg = dgl.mean_nodes(g, 'h')
        # Drop only the trailing dimension: a bare squeeze() would also drop
        # the batch dimension for a single-graph batch, yielding a 0-d tensor
        # that BCELoss rejects against a target of shape (1,).
        return torch.sigmoid(self.classify(hg)).squeeze(-1)

    def _batch_loss(self, batch):
        """Compute the binary cross-entropy loss for one ``(graph, label)`` batch."""
        bg, label = batch
        return self.loss_func(self(bg), label)

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch``."""
        return self._batch_loss(batch)

    def validation_step(self, batch, batch_idx):
        """Log the validation loss for ``batch`` under the key ``'val_loss'``."""
        self.log('val_loss', self._batch_loss(batch))

    def configure_optimizers(self):
        """Use Adam with a learning rate of 1e-3 over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [13]:
# Early stopping on the logged 'val_loss' (patience 5, minimum improvement 0.01).
callbacks = [
    EarlyStopping(monitor='val_loss', patience=5, min_delta=0.01),
]

In [14]:
# Checkpoints go to ../models; epoch and val_loss are embedded in the file name.
# mode='min' because a lower val_loss is better.
checkpointer = ModelCheckpoint(filepath='../models/12Nov-{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', mode='min')

In [15]:
# Fresh model plus a Trainer wired to the early-stopping and checkpoint callbacks.
# NOTE(review): gpus=[3] pins training to GPU index 3 — adjust for the machine at hand.
classifier= MalwareClassifier(**model_kwargs)
trainer = pl.Trainer(callbacks=callbacks, checkpoint_callback=checkpointer, gpus=[3])

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]


In [16]:
# Training is skipped when `train` is False; the testing section loads a saved
# checkpoint either way.
if train:
    trainer.fit(classifier, train_data, val_data)


  | Name      | Type     | Params
---------------------------------------
0 | conv1     | SAGEConv | 768   
1 | conv2     | SAGEConv | 8 K   
2 | classify  | Linear   | 65    
3 | loss_func | BCELoss  | 0     


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




## Testing 

In [17]:
# Replace the in-memory model with the weights of one specific saved checkpoint.
# NOTE(review): the path is hard-coded to a particular epoch/val_loss file and
# must be updated after retraining.
classifier = MalwareClassifier.load_from_checkpoint('../models/12Nov-epoch=15-val_loss=0.28.pt.ckpt', **model_kwargs)

In [20]:
# Score the whole test split as one batched graph. Inference runs in eval mode
# and without autograd, so no computation graph is kept for the full test set.
classifier.eval()
with torch.no_grad():
    predicted = classifier(dgl.batch([g for g, _ in test_dataset]))
predicted

tensor([9.3014e-04, 5.8755e-01, 3.4881e-04,  ..., 7.8131e-01, 8.5041e-02,
        6.6010e-01], grad_fn=<SqueezeBackward0>)

In [21]:
# Work on an independent copy: detach() alone shares storage with `predicted`,
# so the in-place thresholding below would also overwrite the probabilities.
predicted_mod = predicted.detach().clone()

In [23]:
# Threshold the probabilities in place: above 0.5 -> 1, below 0.5 -> 0.
# A value of exactly 0.5 is left unchanged here (the later .long() truncates it to 0).
predicted_mod[predicted_mod>0.5] = 1
predicted_mod[predicted_mod<0.5] = 0

In [29]:
# Integer view of the thresholded predictions.
predicted_mod.long()

tensor([0, 1, 0,  ..., 1, 0, 1])

In [25]:
# Ground-truth binary targets for the test split, using the same mapping as
# collate(): class index 2 -> 1, every other class index -> 0.
actual = (torch.tensor([label for _, label in test_dataset]) == 2).long()
actual

tensor([0, 1, 0,  ..., 0, 0, 1])

In [47]:
# (number of test samples, number of them with target 0)
len(actual), int((actual == 0).sum())

(3302, 2494)

In [48]:
# Number of test samples with target 1. Computed from `actual` directly rather
# than from IPython's `_` (last output), which only works when the previous
# cell was the one executed immediately before this one.
len(actual) - len(torch.where(actual==0)[0])

808

In [31]:
# Per-class precision / recall / F1 on the test split (class 1 = original class index 2).
print(M.classification_report(actual, predicted_mod.long(), digits=4))

              precision    recall  f1-score   support

           0     0.9594    0.9379    0.9485      2494
           1     0.8206    0.8775    0.8481       808

    accuracy                         0.9231      3302
   macro avg     0.8900    0.9077    0.8983      3302
weighted avg     0.9254    0.9231    0.9239      3302



In [49]:
# Confusion matrix: rows are true classes, columns are predicted classes.
M.confusion_matrix(actual, predicted_mod.long())

array([[2339,  155],
       [  99,  709]])

## Results
Accuracy - 92.31%,
Precision (weighted avg) - 0.9254,
Recall (weighted avg) - 0.9231,
F1 (weighted avg) - 0.9239