In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import multiprocessing as mp
from datetime import datetime
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
# from src.model import *
# from src.util import *
# Pin this process to physical GPU 3. CUDA reads this variable when it is first
# initialised, so it must be set before any CUDA call is made.
os.environ.update({"CUDA_VISIBLE_DEVICES": '3'})

In [2]:
# Machine-specific absolute locations; adjust when running elsewhere.
# Directory where per-run checkpoint folders are created.
trained_model_path = '/workdir/security/home/junjiehuang2468/paper/trained_models_weight/kaggle_miscrosoft/'
# Root of the Kaggle data; the split manifests (*_dataset.csv) live here.
data_path = "/workdir/security/home/junjiehuang2468/paper/data/kaggle/"
# Raw sample files used for training.
train_data_path = os.path.join(data_path, "malwares/")
# Training label file.
train_label_path = os.path.join(data_path, "train_labels.csv")

In [3]:
CUDA = torch.cuda.is_available()  # use the GPU whenever one is visible
NUM_WORKERS = 16  # worker processes per DataLoader
BATCH_SIZE = 128  # samples per mini-batch
# Number of leading *bytes* (not bits) kept from each sample file; shorter files
# are zero-padded to this length by the dataset.
LEAVE_BIT_NUMBER = 500_000
KERNEL_SIZE = 500  # kernel size & stride of the MalConv gated convolutions (default: 500)

In [4]:
# Pre-split manifests; downstream cells read their "id" and "labels" columns.
trainset, validset = (
    pd.read_csv(data_path + 'train_dataset.csv'),
    pd.read_csv(data_path + 'valid_dataset.csv'),
)

In [5]:
class ExeDataset(Dataset):
    """Map-style dataset yielding fixed-length byte-token arrays and their labels.

    Sample ``idx`` is read in binary mode from
    ``data_path + malware_names[idx] + '.txt'``. The first ``leave_bit_num`` bytes
    are kept (despite the name, the unit is bytes, not bits). Each byte is shifted
    by +1 into 1..256 so that 0 is free to act as padding. The result is right-padded
    with 0 up to ``leave_bit_num``.

    Args:
        malware_names: Sequence of sample ids (file stems, without '.txt').
        data_path: Directory prefix concatenated directly with the file name, so it
            should end with a path separator.
        labels: Sequence of labels aligned with ``malware_names``.
        leave_bit_num: Fixed output length in bytes/tokens.
    """

    def __init__(self, malware_names, data_path, labels, leave_bit_num):
        self.malware_names = malware_names
        self.data_path = data_path
        self.labels = labels
        self.leave_bit_num = leave_bit_num

    def __len__(self):
        return len(self.malware_names)

    def __getitem__(self, idx):
        """Return ``(tokens, label)``.

        ``tokens`` is an int64 array of length ``leave_bit_num``. ``label`` is a
        1-element array holding ``labels[idx]``.
        """
        with open(self.data_path + self.malware_names[idx] + '.txt', 'rb') as fp:
            # Read only the prefix we keep instead of loading the whole file.
            raw = fp.read(self.leave_bit_num)
        # Zero-initialised buffer: the tail past the file's end is the padding.
        data = np.zeros(self.leave_bit_num, dtype=np.int64)
        # Vectorised +1 shift (1..256); 0 stays reserved for padding.
        data[:len(raw)] = np.frombuffer(raw, dtype=np.uint8).astype(np.int64) + 1
        return data, np.array([self.labels[idx]])

In [6]:
def _exe_dataset_from(frame):
    """Wrap one split's manifest (columns "id" and "labels") as an ExeDataset."""
    # NOTE(review): both splits read samples from train_data_path, presumably because
    # the validation split was carved out of the Kaggle training files — confirm.
    return ExeDataset(
        frame["id"].tolist(),
        train_data_path,
        frame["labels"].tolist(),
        LEAVE_BIT_NUMBER,
    )


train_dataset = _exe_dataset_from(trainset)
valid_dataset = _exe_dataset_from(validset)

In [7]:
# Reshuffle the training set every epoch so mini-batches are not identical (and
# possibly ordered by the manifest) from epoch to epoch. Validation order does not
# affect accuracy, so it stays fixed. pin_memory speeds up host-to-GPU copies.
trainloader = DataLoader(
    dataset = train_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = NUM_WORKERS,
    pin_memory = True
)
validloader = DataLoader(
    dataset = valid_dataset,
    batch_size = BATCH_SIZE,
    shuffle = False,
    num_workers = NUM_WORKERS,
    pin_memory = True
)

In [8]:
# class MalConv(nn.Module):
#     def __init__(self, input_length=2000000, window_size=500):
#         super(MalConv, self).__init__()

#         self.embed = nn.Embedding(257, 8, padding_idx=0)

#         self.conv_1 = nn.Conv1d(4, 128, window_size, stride=window_size, bias=True)
#         self.conv_2 = nn.Conv1d(4, 128, window_size, stride=window_size, bias=True)

#         self.BatchNorm1d = nn.BatchNorm1d(128)

#         self.pooling = nn.MaxPool1d(int(input_length / window_size))

#         self.fc_1 = nn.Linear(128, 128)
#         self.fc_2 = nn.Linear(128, 9)

#         # self.BatchNorm1d = nn.BatchNorm1d(128)

#         self.sigmoid = nn.Sigmoid()

#     # self.softmax = nn.Softmax()

#     def forward(self, x):
#         x = self.embed(x)
#         # Channel first
#         x = torch.transpose(x, -1, -2)

#         cnn_value = self.conv_1(x.narrow(-2, 0, 4))
#         cnn_value = self.BatchNorm1d(cnn_value)
#         gating_weight = self.sigmoid(self.conv_2(x.narrow(-2, 4, 4)))

#         x = cnn_value * gating_weight
#         x = self.pooling(x)

#         x = x.view(-1, 128)
#         x = self.fc_1(x)
#         # x = self.BatchNorm1d(x)
#         x = self.fc_2(x)
#         # x = self.sigmoid(x)

#         return x

In [9]:
class Model(nn.Module):
    """MalConv-style gated-convolution classifier over byte tokens, with 9 output logits.

    Args:
        data_length: Sequence length the model is built for. ``forward`` must receive
            inputs of exactly this length, so that the max-pool collapses the conv
            output to a single position. Floats such as the ``2e6`` default are
            truncated to int.
        kernel_size: Kernel size and stride shared by both gated convolutions.
    """

    def __init__(self, data_length = 2e6, kernel_size = 500):
        super().__init__()
        # 257-token vocabulary; index 0 is the padding token (its embedding stays zero).
        self.embedding = nn.Embedding(257, 8, padding_idx=0)
        # Content branch: convolves the first 4 embedding channels.
        self.conv_layer_1 = nn.Conv1d(4, 128, kernel_size, stride = kernel_size, bias = True)
        # self.bn_1 = nn.BatchNorm1d(128)
        # Gate branch: convolves the last 4 embedding channels; sigmoid applied in forward().
        self.conv_layer_2 = nn.Conv1d(4, 128, kernel_size, stride = kernel_size, bias = True)
        # MaxPool1d needs an int kernel: with the float default 2e6, a plain
        # `data_length//kernel_size` yields 4000.0 and forward() fails.
        self.pool_layer_2 = nn.MaxPool1d(int(data_length) // kernel_size)
        self.fc_layer_3 = nn.Linear(128, 128)
        # NOTE(review): there is no activation between fc_layer_3 and fc_layer_4, so the
        # two layers compose to one affine map. Kept as-is so existing checkpoints
        # behave the same.
        self.fc_layer_4 = nn.Linear(128, 9)

    def forward(self,x):
        """Map token ids of shape (batch, data_length) to logits of shape (batch, 9)."""
        x = self.embedding(x)           # (B, L) -> (B, L, 8)
        x = x.transpose(-1,-2)          # -> (B, 8, L): channels first for Conv1d
        x_conv_1 = self.conv_layer_1(x[:,:4,:])
        x_conv_2 = torch.sigmoid(self.conv_layer_2(x[:,4:,:]))
        # Gate the content branch element-wise by the sigmoid branch.
        x = x_conv_1*x_conv_2
        del x_conv_1,x_conv_2
        # Drop only the pooled length axis (size 1). A bare squeeze() also dropped
        # the batch axis when the batch held a single sample.
        x = self.pool_layer_2(x).squeeze(-1)
        x = self.fc_layer_3(x)
        x = self.fc_layer_4(x)
        # x = torch.sigmoid(x)
        return x

In [10]:
# def mp_func(i,inpu,te,gr):
#     check = 0
#     for j,(inp,g,t) in enumerate(zip(inpu,gr,te)):
#         if inp != 0: 
#             check = j
#             continue
#         max_idx = np.argmin(g).tolist()
#         org_max_idx = np.argmax(t).tolist()
#         if g[max_idx] > 0: continue
#         te[j][org_max_idx] = 0
#         te[j][max_idx] = 1
#     return [i,te,check]

In [11]:
# class MalConv(nn.Module):
#     def __init__(self, input_length=2000000, window_size=500):
#         super(MalConv, self).__init__()

#         self.embed = nn.Embedding(257, 8, padding_idx=0)

#         self.conv_1 = nn.Conv1d(4, 128, window_size, stride=window_size, bias=True)
#         self.conv_2 = nn.Conv1d(4, 128, window_size, stride=window_size, bias=True)

#         self.BatchNorm1d = nn.BatchNorm1d(128)

#         self.pooling = nn.MaxPool1d(int(input_length / window_size))

#         self.fc_1 = nn.Linear(128, 128)
#         self.fc_2 = nn.Linear(128, 9)

#         # self.BatchNorm1d = nn.BatchNorm1d(128)

#         self.sigmoid = nn.Sigmoid()

#     # self.softmax = nn.Softmax()
    
#     def forward(self, input_, loss_fn, fake_label, label):
#         temp = F.one_hot(input_,num_classes=257).float()
#         temp.requires_grad = True
#         temp.retain_grad()
#         for _ in range(10):
#             x = temp @ self.embed.weight
#             x = torch.transpose(x, -1, -2)
#             cnn_value = self.conv_1(x.narrow(-2, 0, 4))
#             cnn_value = self.BatchNorm1d(cnn_value)
#             gating_weight = self.sigmoid(malconv.conv_2(x.narrow(-2, 4, 4)))
#             x = cnn_value * gating_weight
#             x = self.pooling(x)
#             x = x.view(-1, 128)
#             x = self.fc_1(x)
#             x = self.fc_2(x)
            
#             print((torch.argmax(torch.softmax(x,dim=-1),dim=-1) == label).float().mean())
            
#             loss = loss_fn(x,fake_label).cuda()
#             print(loss)
#             loss.backward()
            
#             data = [(i,inpu,te,gr) for i,(inpu,te,gr) in enumerate(zip(
#                 input_.detach().cpu().numpy(),
#                 temp.detach().cpu().numpy(),
#                 temp.grad.detach().cpu().numpy()
#             ))]
#             with mp.Pool(processes=24 if len(data) > 24 else len(data)) as pool:
#                 results = pool.starmap(mp_func,data)
            
#             print(sum(r[2] for r in results)/len(results))
#             results = sorted(results,key = lambda x: x[0])
#             for i in range(len(temp)):
#                 temp.data[i] = torch.tensor(results[i][1], dtype=torch.float, requires_grad=True).cuda()
                
#         return x.cpu().detach().numpy(),temp.cpu().detach().numpy()

In [12]:
def train_def(model,trainloader,loss_fn,optim,cuda=True):
    """Run one training epoch over ``trainloader`` and return the updated model.

    Labels are presumably 1-based class ids delivered with shape (batch, 1). They
    are flattened and shifted to 0-based before computing the loss. The running
    training accuracy is shown in the tqdm progress bar.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        trainloader: Iterable of ``(batch_data, batch_label)`` tensor pairs.
        loss_fn: Loss taking ``(logits, 0-based targets)``.
        optim: Optimizer over ``model``'s parameters.
        cuda: Move each batch to the default CUDA device when True.

    Returns:
        ``model`` (trained in place).
    """
    model.train()
    n_correct = 0
    n_seen = 0
    bar = tqdm(trainloader)
    for batch_data, batch_label in bar:
        optim.zero_grad()
        if cuda:
            batch_data = batch_data.cuda()
            batch_label = batch_label.cuda()
        # reshape(-1) keeps a batch axis even for a single-sample batch (squeeze() did not).
        batch_label = batch_label.reshape(-1) - 1

        pred = model(batch_data)
        loss = loss_fn(pred, batch_label)
        loss.backward()
        optim.step()
        # Running accuracy as two counters: O(1) per step rather than re-averaging a
        # list that grows with every batch.
        predicted = pred.argmax(dim=1)
        n_correct += (predicted == batch_label).sum().item()
        n_seen += batch_label.numel()
        bar.set_description(f'train: {n_correct / n_seen:.6}')
    return model

In [13]:
def valid_def(model,validloader,cuda=True):
    """Evaluate ``model`` on ``validloader`` and report classification accuracy.

    Labels are presumably 1-based class ids delivered with shape (batch, 1). They
    are flattened and shifted to 0-based before comparison with the arg-max
    prediction. The running accuracy is shown in the tqdm progress bar.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        validloader: Iterable of ``(batch_data, batch_label)`` tensor pairs.
        cuda: Move each batch to the default CUDA device when True.

    Returns:
        ``(model, accuracy)``. ``accuracy`` is the fraction of correct predictions
        over the whole loader, or NaN if the loader is empty.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    bar = tqdm(validloader)
    # Pure inference: no optimizer involvement and no autograd graph is needed,
    # which saves memory and time.
    with torch.no_grad():
        for batch_data, batch_label in bar:
            if cuda:
                batch_data = batch_data.cuda()
                batch_label = batch_label.cuda()
            # reshape(-1) keeps a batch axis even for a single-sample batch.
            batch_label = batch_label.reshape(-1) - 1

            pred = model(batch_data)
            predicted = pred.argmax(dim=1)
            n_correct += (predicted == batch_label).sum().item()
            n_seen += batch_label.numel()
            bar.set_description(f'test: {n_correct / n_seen:.6}')
    accuracy = n_correct / n_seen if n_seen else float('nan')
    return model, accuracy

In [14]:
model = Model(data_length=LEAVE_BIT_NUMBER,kernel_size=KERNEL_SIZE)
# Move the model to its device before building the optimizer, as PyTorch recommends.
model = model.cuda() if CUDA else model

ce_loss = nn.CrossEntropyLoss()
ce_loss = ce_loss.cuda() if CUDA else ce_loss
optim = Adam(model.parameters())

In [15]:
# Run directory named after the start time, truncated to the minute
# (e.g. "2023-05-01 14:07").
time_dir = datetime.now().strftime('%Y-%m-%d %H:%M')
# makedirs also creates trained_model_path if it is missing, and does not fail when
# the cell is re-run within the same minute.
os.makedirs(f'{trained_model_path}{time_dir}', exist_ok=True)

In [16]:
# Train for a fixed 25 epochs, validating and checkpointing after each one.
# "50w" in the file name presumably denotes the 500,000-byte input length
# (w = 10^4) — confirm.
for epoch in range(25):
    print(epoch)
    model = train_def(model, trainloader, ce_loss, optim, CUDA)
    model, test_acc = valid_def(model, validloader, CUDA)
    save_path = f'{trained_model_path}{time_dir}/50w_epoch:{epoch}_test_acc:{test_acc:.6f}.pt'
    torch.save(model.state_dict(), save_path)

0


train: 0.742121: 100%|██████████| 68/68 [01:31<00:00,  1.35s/it]
test: 0.905704: 100%|██████████| 17/17 [00:29<00:00,  1.72s/it]


1


train: 0.933517: 100%|██████████| 68/68 [01:42<00:00,  1.51s/it]
test: 0.912144: 100%|██████████| 17/17 [00:33<00:00,  1.95s/it]


2


train: 0.971935: 100%|██████████| 68/68 [01:39<00:00,  1.47s/it]
test: 0.934683: 100%|██████████| 17/17 [00:31<00:00,  1.83s/it]


3


train: 0.989878: 100%|██████████| 68/68 [01:37<00:00,  1.44s/it]
test: 0.935603: 100%|██████████| 17/17 [00:29<00:00,  1.72s/it]


4


train: 0.990683: 100%|██████████| 68/68 [01:43<00:00,  1.52s/it]
test: 0.925023: 100%|██████████| 17/17 [00:31<00:00,  1.84s/it]


5


train: 0.995629: 100%|██████████| 68/68 [01:31<00:00,  1.35s/it]
test: 0.940662: 100%|██████████| 17/17 [00:39<00:00,  2.33s/it]


6


train: 0.995514: 100%|██████████| 68/68 [01:43<00:00,  1.52s/it]
test: 0.937443: 100%|██████████| 17/17 [00:32<00:00,  1.94s/it]


7


train: 0.997815: 100%|██████████| 68/68 [01:41<00:00,  1.49s/it]
test: 0.953082: 100%|██████████| 17/17 [00:32<00:00,  1.90s/it]


8


train: 0.9977: 100%|██████████| 68/68 [01:32<00:00,  1.36s/it]  
test: 0.958142: 100%|██████████| 17/17 [00:32<00:00,  1.93s/it]


9


train: 0.9977: 100%|██████████| 68/68 [01:41<00:00,  1.49s/it]  
test: 0.954462: 100%|██████████| 17/17 [00:31<00:00,  1.88s/it]


10


train: 0.998045: 100%|██████████| 68/68 [01:39<00:00,  1.46s/it]
test: 0.945722: 100%|██████████| 17/17 [00:30<00:00,  1.80s/it]


11


train: 0.997239: 100%|██████████| 68/68 [01:40<00:00,  1.47s/it]
test: 0.946182: 100%|██████████| 17/17 [00:30<00:00,  1.77s/it]


12


train: 0.99793: 100%|██████████| 68/68 [01:36<00:00,  1.42s/it] 
test: 0.949402: 100%|██████████| 17/17 [00:33<00:00,  1.97s/it]


13


train: 0.99747: 100%|██████████| 68/68 [01:32<00:00,  1.35s/it] 
test: 0.941122: 100%|██████████| 17/17 [00:33<00:00,  1.99s/it]


14


train: 0.998275: 100%|██████████| 68/68 [01:46<00:00,  1.57s/it]
test: 0.961362: 100%|██████████| 17/17 [00:32<00:00,  1.91s/it]


15


train: 0.99839: 100%|██████████| 68/68 [01:41<00:00,  1.49s/it] 
test: 0.961822: 100%|██████████| 17/17 [00:32<00:00,  1.90s/it]


16


train: 0.998505: 100%|██████████| 68/68 [01:37<00:00,  1.44s/it]
test: 0.960902: 100%|██████████| 17/17 [00:36<00:00,  2.16s/it]


17


train: 0.998735: 100%|██████████| 68/68 [01:49<00:00,  1.61s/it]
test: 0.960442: 100%|██████████| 17/17 [00:31<00:00,  1.84s/it]


18


train: 0.998735: 100%|██████████| 68/68 [01:44<00:00,  1.53s/it]
test: 0.960442: 100%|██████████| 17/17 [00:34<00:00,  2.04s/it]


19


train: 0.998735: 100%|██████████| 68/68 [01:39<00:00,  1.46s/it]
test: 0.959982: 100%|██████████| 17/17 [00:35<00:00,  2.09s/it]


20


train: 0.998735: 100%|██████████| 68/68 [01:44<00:00,  1.53s/it]
test: 0.960902: 100%|██████████| 17/17 [00:34<00:00,  2.01s/it]


21


train: 0.998735: 100%|██████████| 68/68 [01:47<00:00,  1.58s/it]
test: 0.959982: 100%|██████████| 17/17 [00:35<00:00,  2.07s/it]


22


train: 0.998735: 100%|██████████| 68/68 [01:46<00:00,  1.57s/it]
test: 0.960442: 100%|██████████| 17/17 [00:34<00:00,  2.03s/it]


23


train: 0.998735: 100%|██████████| 68/68 [01:36<00:00,  1.42s/it]
test: 0.958602: 100%|██████████| 17/17 [00:37<00:00,  2.20s/it]


24


train: 0.998735: 100%|██████████| 68/68 [01:35<00:00,  1.40s/it]
test: 0.959982: 100%|██████████| 17/17 [00:34<00:00,  2.01s/it]
