In [1]:
import multiprocessing as mp
import os
import pickle
from collections import defaultdict
from functools import partial

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
# from src.model import *
# from src.util import *
# Set CUDA_VISIBLE_DEVICES to '2' for this process, i.e. select which GPU the
# CUDA runtime exposes. This runs in the import cell, ahead of the first
# torch.cuda call made by the notebook.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

In [2]:
# Locations of the trained weights and of the EMBER 2018 data.
# The defaults are the directories the notebook was originally run with; set the
# EMBER_MODEL_DIR / EMBER_DATA_DIR environment variables to point elsewhere
# without editing this cell. Both values are used as plain string prefixes
# further down, so os.path.join(..., '') guarantees a trailing separator.
trained_model_path = os.path.join(
    os.environ.get(
        'EMBER_MODEL_DIR',
        '/workdir/security/home/junjiehuang2468/paper/trained_models_weight/ember/',
    ),
    '',
)
# Checkpoint file, relative to trained_model_path.
best_trained_model = '2022-01-18 14:55/2w_epoch:0_test_acc:0.890858.pt'
data_path = os.path.join(
    os.environ.get(
        'EMBER_DATA_DIR',
        "/workdir/security/home/junjiehuang2468/paper/data/ember2018/",
    ),
    '',
)
# Directories holding one '<id>.txt' file per sample.
train_data_path = data_path + "malwares/"
test_data_path = data_path + "test_malwares/"

In [3]:
# Run-time configuration shared by the cells below.
CUDA = torch.cuda.is_available()  # move model / batches to the GPU when one is usable
NUM_WORKERS = 18                  # DataLoader worker processes
BATCH_SIZE = 10                   # samples per DataLoader batch
LEAVE_BIT_NUMBER = 20_000         # tokens kept per sample (ExeDataset length / Model data_length)
KERNEL_SIZE = 500                 # kernel size passed to Model

In [4]:
# Load the three split manifests. The 'id' and 'labels' columns are the ones
# consumed when the datasets are built from these frames.
trainset = pd.read_csv(data_path + 'train_dataset.csv')
validset = pd.read_csv(data_path + 'valid_dataset.csv')
testset = pd.read_csv(data_path + 'test_dataset.csv')
# Evaluate only on rows labelled 1 (presumably the malicious class — confirm).
# A boolean mask always yields a DataFrame; the previous
# iloc[np.argwhere(...).squeeze(), :] collapsed to a Series when exactly one
# row matched, which breaks the later testset["id"].tolist().
testset = testset[testset['labels'] == 1]

In [5]:
class ExeDataset(Dataset):
    """Fixed-length token sequences read from one binary file per sample.

    Sample ``idx`` is read from ``data_path + malware_names[idx] + '.txt'``.
    Every byte value (0-255) is shifted up by one, so tokens lie in 1..256 and
    0 is free to serve as padding. Files shorter than ``leave_bit_num`` bytes
    are right-padded with 0; longer files are truncated.

    Args:
        malware_names: sequence of sample ids (file names without '.txt').
        data_path: directory prefix, concatenated as a string (must end with
            a path separator).
        labels: sequence of labels aligned with ``malware_names``.
        leave_bit_num: number of tokens every returned sequence has.
    """

    def __init__(self, malware_names, data_path, labels, leave_bit_num):
        self.malware_names = malware_names
        self.data_path = data_path
        self.labels = labels
        self.leave_bit_num = leave_bit_num

    def __len__(self):
        """Return the number of samples."""
        return len(self.malware_names)

    def __getitem__(self, idx):
        """Return ``(tokens, label)`` for sample ``idx``.

        ``tokens`` is an int64 array of length ``leave_bit_num``; ``label`` is
        a length-1 array holding ``labels[idx]``.
        """
        with open(self.data_path + self.malware_names[idx] + '.txt', 'rb') as fp:
            # Read only the prefix that is kept instead of the whole file.
            raw = fp.read(self.leave_bit_num)

        data = np.zeros(self.leave_bit_num, dtype=np.int64)
        # Widen to int64 before the +1 shift so that byte 255 becomes 256
        # rather than wrapping around in uint8.
        data[:len(raw)] = np.frombuffer(raw, dtype=np.uint8).astype(np.int64) + 1

        return data, np.array([self.labels[idx]])

In [6]:
def _make_dataset(frame, directory):
    """Wrap one split's 'id' / 'labels' columns in an ExeDataset reading from ``directory``."""
    return ExeDataset(
        malware_names=frame["id"].tolist(),
        data_path=directory,
        labels=frame["labels"].tolist(),
        leave_bit_num=LEAVE_BIT_NUMBER,
    )


# Train and validation samples are both read from train_data_path; the test
# split has its own directory.
train_dataset = _make_dataset(trainset, train_data_path)
valid_dataset = _make_dataset(validset, train_data_path)
test_dataset = _make_dataset(testset, test_data_path)

In [7]:
# All three splits use the same loader settings. Order is kept fixed
# (shuffle=False): the evaluation loop derives a sample's position from
# step * BATCH_SIZE + index-in-batch.
_loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
trainloader, validloader, testloader = (
    DataLoader(dataset=split, **_loader_options)
    for split in (train_dataset, valid_dataset, test_dataset)
)

In [8]:
class Model(nn.Module):
    """Gated 1-D convolutional classifier over token sequences.

    Tokens (0 = padding, 1..256) are embedded into 8 dimensions. The first four
    embedding channels feed ``conv_layer_1``, the last four feed
    ``conv_layer_2`` whose sigmoid output gates the first. The gated feature
    map is max-pooled over ``data_length // kernel_size`` positions and passed
    through two linear layers to produce 2 logits.

    Args:
        data_length: length of the input sequences. Floats such as the default
            ``2e6`` are accepted and converted to an integer pooling size.
        kernel_size: kernel size and stride of both convolutions.
    """

    def __init__(self, data_length = 2e6, kernel_size = 500):
        super().__init__()
        self.embedding = nn.Embedding(257, 8, padding_idx=0)
        self.conv_layer_1 = nn.Conv1d(4, 128, kernel_size, stride = kernel_size, bias = True)
        # self.bn_1 = nn.BatchNorm1d(128)
        self.conv_layer_2 = nn.Conv1d(4, 128, kernel_size, stride = kernel_size, bias = True)
        # `2e6 // 500` is the float 4000.0; MaxPool1d needs an integer window.
        self.pool_layer_2 = nn.MaxPool1d(int(data_length // kernel_size))
        self.fc_layer_3 = nn.Linear(128, 128)
        self.fc_layer_4 = nn.Linear(128, 2)

    def forward(self, x):
        """Classify a batch of token sequences.

        Args:
            x: integer tensor of token ids, shape (batch, length).

        Returns:
            ``(logits, embedd_x)``: logits of shape (batch, 2) and the embedded
            input of shape (batch, length, 8). When autograd is recording,
            ``embedd_x`` retains its gradient so callers can read
            ``embedd_x.grad`` after a backward pass.
        """
        embedd_x = self.embedding(x)
        # retain_grad() raises on tensors that do not require grad (e.g. under
        # torch.no_grad()), so only request it when a graph is being built.
        if embedd_x.requires_grad:
            embedd_x.retain_grad()
        x = embedd_x.transpose(-1, -2)
        x_conv_1 = self.conv_layer_1(x[:, :4, :])
        x_conv_2 = torch.sigmoid(self.conv_layer_2(x[:, 4:, :]))
        x = x_conv_1 * x_conv_2
        del x_conv_1, x_conv_2
        # Drop only the pooled length axis; a bare squeeze() would also remove
        # the batch axis for a batch of one.
        x = self.pool_layer_2(x).squeeze(-1)
        x = self.fc_layer_3(x)
        x = self.fc_layer_4(x)
        return x, embedd_x

In [9]:
# Build the network and the loss, place them on the GPU when available, and
# only then create the optimizer so it is constructed from the parameters on
# their final device.
model = Model(data_length=LEAVE_BIT_NUMBER,kernel_size=KERNEL_SIZE)
ce_loss = nn.CrossEntropyLoss()

if CUDA:
    # The previous conditional expression fell back to the misspelled name
    # `ce_less`, raising NameError on machines without CUDA.
    model = model.cuda()
    ce_loss = ce_loss.cuda()

optim = Adam(model.parameters())

In [10]:
# Restore the trained weights. map_location='cpu' lets a checkpoint saved on
# any GPU be read on this machine; load_state_dict then copies the values into
# the model's parameters on whatever device they live on.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(
    torch.load(trained_model_path + best_trained_model, map_location='cpu')
)

<All keys matched successfully>

In [11]:
# Handle on the model's embedding layer (token id -> 8-dim vector).
# NOTE(review): not referenced in the visible part of the notebook — confirm
# whether it is still needed.
model_embedding_layer = model.embedding

In [12]:
# Maps "number of padding tokens in the original sample" -> list of test-set
# positions that are still classified correctly after the padding attack.
# (defaultdict is imported with the other imports at the top of the notebook.)
dic = defaultdict(list)

In [None]:
# Number of gradient-guided refinement rounds applied to each batch.
N_ATTACK_ITERS = 8


def _batch_accuracy(logits, labels):
    """Return the fraction of samples whose arg-max logit equals ``labels``."""
    _, predicted = torch.max(logits, 1)
    return (labels.cpu().data.numpy() == predicted.cpu().data.numpy()).mean()


# Embedding vector of every token id (0 = padding, 1..256 = byte value + 1),
# shape (257, D). The weights are never updated in this cell (no optimizer
# step), so the table is computed once, on the device the model lives on,
# instead of calling .cuda() unconditionally inside the loop.
with torch.no_grad():
    all_embedd = model.embedding(
        torch.arange(257, device=model.embedding.weight.device)
    )

# One row per batch: [clean acc, random-padding acc, acc after each refinement round].
total_batch_acc = []
bar = tqdm(testloader)
# bar = tqdm(validloader)
for step, (batch_data, batch_label) in enumerate(bar):
    optim.zero_grad()
    temp_total_acc = []
    batch_data = batch_data.cuda() if CUDA else batch_data
    batch_label = batch_label.cuda() if CUDA else batch_label
    batch_label = batch_label.reshape(-1)

    # Positions that hold the padding token in the untouched input; these are
    # the only positions the attack is allowed to rewrite.
    # NOTE(review): the original code carried an unexplained "wrong" marker on
    # this mask; its definition is unchanged here.
    pad_mask = batch_data == 0

    # 1) Accuracy on the unmodified samples.
    pred, embedd_x = model(batch_data)
    temp_total_acc.append(_batch_accuracy(pred, batch_label))

    # 2) Accuracy after filling the padding with uniformly random tokens in [1, 256].
    random_tokens = torch.randint_like(batch_data, low=1, high=257)
    batch_data = torch.where(pad_mask, random_tokens, batch_data)
    pred, embedd_x = model(batch_data)
    temp_total_acc.append(_batch_accuracy(pred, batch_label))

    # 3) Gradient-guided refinement of the padding tokens.
    for _ in range(N_ATTACK_ITERS):
        # `pred` / `embedd_x` always come from the latest forward pass on the
        # current `batch_data`, so no extra forward pass is needed here.
        # NOTE(review): the objective is the mean over all logits rather than a
        # class-specific loss — confirm this is intended.
        pred.mean().backward()

        # Pure tensor arithmetic from here on: no autograd graph is needed for
        # the (B, L, 257, D) intermediates.
        with torch.no_grad():
            # Unit vector along the negative gradient at every position;
            # zero-gradient positions give 0/0 = NaN, mapped to 0.
            neg_grad = -embedd_x.grad
            direction = neg_grad / torch.linalg.norm(neg_grad, ord=2, dim=-1, keepdim=True)
            direction = direction.nan_to_num(0)

            current = embedd_x.unsqueeze(2)      # (B, L, 1, D)
            direction = direction.unsqueeze(2)   # (B, L, 1, D)

            # sb: signed length of the projection of (candidate - current) on
            # the direction, for each of the 257 candidate tokens.
            sb = torch.sum(direction * (all_embedd - current), dim=-1, keepdim=True)
            # db: L1 distance from each candidate to its projection on the line
            # through `current` along `direction`.
            db = torch.linalg.norm(all_embedd - (current + sb * direction), ord=1, dim=-1)
            # Candidates lying against the direction (sb <= 0) are pushed
            # behind every admissible one before taking the arg-min.
            score = torch.where(sb.squeeze(-1) > 0, db, torch.max(db) + 1)
            nearest = torch.argmin(score, dim=-1)  # (B, L) token ids

        batch_data = torch.where(pad_mask, nearest, batch_data)
        pred, embedd_x = model(batch_data)
        temp_total_acc.append(_batch_accuracy(pred, batch_label))

    # Record every sample that is still classified correctly after the attack,
    # keyed by how many padding tokens it originally had.
    still_correct = batch_label.detach().cpu().numpy() == np.argmax(pred.detach().cpu().numpy(), 1)
    for idx in np.flatnonzero(still_correct):
        idx = int(idx)
        sample_idx = step * BATCH_SIZE + idx
        padding_num = int(pad_mask[idx].sum())
        print(sample_idx)
        print(padding_num)
        dic[padding_num].append(sample_idx)

    total_batch_acc.append(temp_total_acc)
    # Running mean of each accuracy column over the batches seen so far.
    total_batch_acc_str = '[' + ' '.join('%.7f' % a for a in np.mean(total_batch_acc, axis=0)) + ']'
    bar.set_description(f'{total_batch_acc_str}')

[0.8730769 0.4692308 0.0038462 0.0038462 0.0038462 0.0038462 0.0038462 0.0038462 0.0038462 0.0038462]:   0%|          | 26/10000 [00:12<1:12:45,  2.28it/s]

253
0


[0.8780488 0.4756098 0.0048780 0.0048780 0.0048780 0.0048780 0.0048780 0.0048780 0.0048780 0.0048780]:   0%|          | 41/10000 [00:19<1:12:36,  2.29it/s]

400
0


[0.8600000 0.4538462 0.0046154 0.0046154 0.0046154 0.0046154 0.0046154 0.0046154 0.0046154 0.0046154]:   1%|          | 65/10000 [00:29<1:13:28,  2.25it/s]

647
0


[0.8652632 0.4652632 0.0042105 0.0042105 0.0042105 0.0042105 0.0042105 0.0042105 0.0042105 0.0042105]:   1%|          | 95/10000 [00:43<1:13:21,  2.25it/s]

942
0


[0.8691667 0.4566667 0.0041667 0.0041667 0.0041667 0.0041667 0.0041667 0.0041667 0.0041667 0.0041667]:   1%|          | 120/10000 [00:54<1:13:15,  2.25it/s]

1197
0


[0.8573427 0.4496503 0.0041958 0.0041958 0.0041958 0.0041958 0.0041958 0.0041958 0.0041958 0.0041958]:   1%|▏         | 143/10000 [01:04<1:13:20,  2.24it/s]

1423
0


[0.8591716 0.4502959 0.0041420 0.0041420 0.0041420 0.0041420 0.0041420 0.0041420 0.0041420 0.0041420]:   2%|▏         | 169/10000 [01:16<1:13:06,  2.24it/s]

1681
0


[0.8590643 0.4508772 0.0046784 0.0046784 0.0046784 0.0046784 0.0046784 0.0046784 0.0046784 0.0046784]:   2%|▏         | 171/10000 [01:16<1:13:06,  2.24it/s]

1700
0


[0.8586207 0.4517241 0.0051724 0.0051724 0.0051724 0.0051724 0.0051724 0.0051724 0.0051724 0.0051724]:   2%|▏         | 174/10000 [01:18<1:13:14,  2.24it/s]

1735
0


[0.8601124 0.4505618 0.0056180 0.0056180 0.0056180 0.0056180 0.0056180 0.0056180 0.0056180 0.0056180]:   2%|▏         | 178/10000 [01:20<1:13:05,  2.24it/s]

1770
0


[0.8589744 0.4512821 0.0056410 0.0056410 0.0056410 0.0056410 0.0056410 0.0056410 0.0056410 0.0056410]:   2%|▏         | 195/10000 [01:27<1:13:00,  2.24it/s]

1944
0


[0.8598985 0.4507614 0.0060914 0.0060914 0.0060914 0.0060914 0.0060914 0.0060914 0.0060914 0.0060914]:   2%|▏         | 197/10000 [01:28<1:12:56,  2.24it/s]

1964
0


[0.8581197 0.4452991 0.0055556 0.0055556 0.0055556 0.0055556 0.0055556 0.0055556 0.0055556 0.0055556]:   2%|▏         | 234/10000 [01:45<1:12:32,  2.24it/s]

2335
0


[0.8580913 0.4443983 0.0058091 0.0058091 0.0058091 0.0058091 0.0058091 0.0058091 0.0058091 0.0058091]:   2%|▏         | 241/10000 [01:48<1:12:39,  2.24it/s]

2408
0


[0.8581301 0.4455285 0.0060976 0.0060976 0.0060976 0.0060976 0.0060976 0.0060976 0.0060976 0.0060976]:   2%|▏         | 246/10000 [01:50<1:12:43,  2.24it/s]

2457
0


[0.8571429 0.4455598 0.0061776 0.0061776 0.0061776 0.0061776 0.0061776 0.0061776 0.0061776 0.0061776]:   3%|▎         | 259/10000 [01:56<1:12:15,  2.25it/s]

2586
0


[0.8573077 0.4450000 0.0065385 0.0065385 0.0065385 0.0065385 0.0065385 0.0065385 0.0065385 0.0065385]:   3%|▎         | 260/10000 [01:56<1:12:21,  2.24it/s]

2591
0


[0.8608997 0.4425606 0.0062284 0.0062284 0.0062284 0.0062284 0.0062284 0.0062284 0.0062284 0.0062284]:   3%|▎         | 289/10000 [02:09<1:12:26,  2.23it/s]

2882
0


[0.8612500 0.4437500 0.0059375 0.0059375 0.0059375 0.0059375 0.0059375 0.0059375 0.0059375 0.0059375]:   3%|▎         | 320/10000 [02:23<1:11:59,  2.24it/s]

3191
0


[0.8622093 0.4473837 0.0058140 0.0058140 0.0058140 0.0058140 0.0058140 0.0058140 0.0058140 0.0058140]:   3%|▎         | 344/10000 [02:34<1:11:41,  2.24it/s]

3432
0


[0.8617816 0.4465517 0.0060345 0.0060345 0.0060345 0.0060345 0.0060345 0.0060345 0.0060345 0.0060345]:   3%|▎         | 348/10000 [02:36<1:11:29,  2.25it/s]

3472
0


[0.8611268 0.4456338 0.0061972 0.0061972 0.0061972 0.0061972 0.0061972 0.0061972 0.0061972 0.0061972]:   4%|▎         | 355/10000 [02:39<1:11:32,  2.25it/s]

3549
0


[0.8616798 0.4443570 0.0060367 0.0060367 0.0060367 0.0060367 0.0060367 0.0060367 0.0060367 0.0060367]:   4%|▍         | 381/10000 [02:50<1:11:38,  2.24it/s]

3809
0


[0.8616580 0.4435233 0.0062176 0.0062176 0.0062176 0.0062176 0.0062176 0.0062176 0.0062176 0.0062176]:   4%|▍         | 386/10000 [02:52<1:11:22,  2.25it/s]

3858
0


[0.8615776 0.4437659 0.0063613 0.0063613 0.0063613 0.0063613 0.0063613 0.0063613 0.0063613 0.0063613]:   4%|▍         | 393/10000 [02:56<1:11:10,  2.25it/s]

3928
0


[0.8600467 0.4413551 0.0060748 0.0060748 0.0060748 0.0060748 0.0060748 0.0060748 0.0060748 0.0060748]:   4%|▍         | 428/10000 [03:11<1:11:04,  2.24it/s]

4275
0


[0.8596774 0.4412442 0.0062212 0.0062212 0.0062212 0.0062212 0.0062212 0.0062212 0.0062212 0.0062212]:   4%|▍         | 434/10000 [03:14<1:10:37,  2.26it/s]

4332
0


[0.8606407 0.4418764 0.0064073 0.0064073 0.0064073 0.0064073 0.0064073 0.0064073 0.0064073 0.0064073]:   4%|▍         | 437/10000 [03:15<1:10:44,  2.25it/s]

4366
0


[0.8611738 0.4419865 0.0065463 0.0065463 0.0065463 0.0065463 0.0065463 0.0065463 0.0065463 0.0065463]:   4%|▍         | 443/10000 [03:18<1:10:48,  2.25it/s]

4425
0


[0.8601732 0.4415584 0.0064935 0.0064935 0.0064935 0.0064935 0.0064935 0.0064935 0.0064935 0.0064935]:   5%|▍         | 462/10000 [03:26<1:11:07,  2.24it/s]

4619
0


[0.8593361 0.4398340 0.0064315 0.0064315 0.0064315 0.0064315 0.0064315 0.0064315 0.0064315 0.0064315]:   5%|▍         | 482/10000 [03:35<1:10:18,  2.26it/s]

4819
0


[0.8597980 0.4408081 0.0064646 0.0064646 0.0064646 0.0064646 0.0064646 0.0064646 0.0064646 0.0064646]:   5%|▍         | 495/10000 [03:41<1:10:29,  2.25it/s]

4941
0


[0.8606786 0.4403194 0.0065868 0.0065868 0.0065868 0.0065868 0.0065868 0.0065868 0.0065868 0.0065868]:   5%|▌         | 501/10000 [03:44<1:10:25,  2.25it/s]

5006
0


[0.8598837 0.4401163 0.0065891 0.0065891 0.0065891 0.0065891 0.0065891 0.0065891 0.0065891 0.0065891]:   5%|▌         | 516/10000 [03:50<1:10:08,  2.25it/s]

5151
0


[0.8607955 0.4412879 0.0066288 0.0066288 0.0066288 0.0066288 0.0066288 0.0066288 0.0066288 0.0066288]:   5%|▌         | 528/10000 [03:56<1:10:19,  2.24it/s]

5273
0


[0.8599631 0.4431734 0.0066421 0.0066421 0.0066421 0.0066421 0.0066421 0.0066421 0.0066421 0.0066421]:   5%|▌         | 542/10000 [04:02<1:09:59,  2.25it/s]

5410
0


[0.8600000 0.4421239 0.0065487 0.0065487 0.0065487 0.0065487 0.0065487 0.0065487 0.0065487 0.0065487]:   6%|▌         | 565/10000 [04:12<1:09:58,  2.25it/s]

5649
0


[0.8606112 0.4417657 0.0064516 0.0064516 0.0064516 0.0064516 0.0064516 0.0064516 0.0064516 0.0064516]:   6%|▌         | 589/10000 [04:23<1:09:54,  2.24it/s]

5886
0


[0.8610829 0.4419628 0.0065990 0.0065990 0.0065990 0.0065990 0.0065990 0.0065990 0.0065990 0.0065990]:   6%|▌         | 591/10000 [04:24<1:09:54,  2.24it/s]

5903
0


[0.8616162 0.4425926 0.0067340 0.0067340 0.0067340 0.0067340 0.0067340 0.0067340 0.0067340 0.0067340]:   6%|▌         | 594/10000 [04:25<1:09:49,  2.25it/s]

5936
0


[0.8607383 0.4422819 0.0068792 0.0068792 0.0068792 0.0068792 0.0068792 0.0068792 0.0068792 0.0068792]:   6%|▌         | 596/10000 [04:26<1:09:52,  2.24it/s]

5951
0


[0.8606965 0.4431177 0.0069652 0.0069652 0.0069652 0.0069652 0.0069652 0.0069652 0.0069652 0.0069652]:   6%|▌         | 603/10000 [04:29<1:09:48,  2.24it/s]

6029
0


[0.8609677 0.4459677 0.0069355 0.0069355 0.0069355 0.0069355 0.0069355 0.0069355 0.0069355 0.0069355]:   6%|▌         | 620/10000 [04:36<1:09:29,  2.25it/s]

6193
0


[0.8598726 0.4447452 0.0070064 0.0070064 0.0070064 0.0070064 0.0070064 0.0070064 0.0070064 0.0070064]:   6%|▋         | 628/10000 [04:40<1:09:24,  2.25it/s]

6278
0


[0.8593060 0.4451104 0.0070978 0.0070978 0.0070978 0.0070978 0.0070978 0.0070978 0.0070978 0.0070978]:   6%|▋         | 634/10000 [04:43<1:09:30,  2.25it/s]

6332
0


[0.8590551 0.4451969 0.0072441 0.0072441 0.0072441 0.0072441 0.0072441 0.0072441 0.0072441 0.0072441]:   6%|▋         | 635/10000 [04:43<1:09:26,  2.25it/s]

6341
0


[0.8562033 0.4424514 0.0070254 0.0070254 0.0070254 0.0070254 0.0070254 0.0070254 0.0070254 0.0070254]:   7%|▋         | 669/10000 [04:58<1:09:03,  2.25it/s]

6686
0


[0.8547850 0.4424411 0.0066574 0.0066574 0.0066574 0.0066574 0.0066574 0.0066574 0.0066574 0.0066574]:   7%|▋         | 721/10000 [05:21<1:08:57,  2.24it/s]

7200
0


[0.8559946 0.4420981 0.0066757 0.0066757 0.0066757 0.0066757 0.0066757 0.0066757 0.0066757 0.0066757]:   7%|▋         | 734/10000 [05:27<1:08:40,  2.25it/s]

7332
0


[0.8566015 0.4413203 0.0062347 0.0061125 0.0061125 0.0061125 0.0061125 0.0061125 0.0061125 0.0061125]:   8%|▊         | 818/10000 [06:04<1:06:02,  2.32it/s]

8179
0


[0.8556962 0.4406214 0.0059839 0.0058688 0.0058688 0.0058688 0.0058688 0.0058688 0.0058688 0.0058688]:   9%|▊         | 869/10000 [06:27<1:08:04,  2.24it/s]

8682
0


[0.8565169 0.4405618 0.0059551 0.0058427 0.0058427 0.0058427 0.0058427 0.0058427 0.0058427 0.0058427]:   9%|▉         | 890/10000 [06:36<1:08:05,  2.23it/s]

8899
0


[0.8568565 0.4411392 0.0056962 0.0055907 0.0055907 0.0055907 0.0055907 0.0055907 0.0055907 0.0055907]:   9%|▉         | 948/10000 [07:02<1:07:34,  2.23it/s]

9476
0


[0.8558700 0.4406709 0.0057652 0.0056604 0.0056604 0.0056604 0.0056604 0.0056604 0.0056604 0.0056604]:  10%|▉         | 954/10000 [07:05<1:07:24,  2.24it/s]

9536
0


[0.8559499 0.4410230 0.0058455 0.0057411 0.0057411 0.0057411 0.0057411 0.0057411 0.0057411 0.0057411]:  10%|▉         | 958/10000 [07:07<1:07:12,  2.24it/s]

9577
0


[0.8560417 0.4414583 0.0059375 0.0058333 0.0058333 0.0058333 0.0058333 0.0058333 0.0058333 0.0058333]:  10%|▉         | 960/10000 [07:08<1:07:13,  2.24it/s]

9595
0


[0.8563017 0.4407025 0.0059917 0.0058884 0.0058884 0.0058884 0.0058884 0.0058884 0.0058884 0.0058884]:  10%|▉         | 968/10000 [07:11<1:07:27,  2.23it/s]

9672
0


[0.8562628 0.4405544 0.0059548 0.0058522 0.0058522 0.0058522 0.0058522 0.0058522 0.0058522 0.0058522]:  10%|▉         | 974/10000 [07:14<1:07:14,  2.24it/s]