In [None]:
import os
import pickle

import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, Subset
from torchtext import *


class MalwareDataset(Dataset):
    """Two-class dataset of pickled sequences stored in two directories.

    Indices ``0 .. len(benign_files) - 1`` map to the (name-sorted) benign files
    with label ``0.0``; the following indices map to the (name-sorted) malware
    files with label ``1.0``. Each sample is ``(torch.tensor(<unpickled object>), label)``.
    Negative indices count from the end of the whole dataset, like a list.

    NOTE(security): ``pickle.load`` can execute arbitrary code; only point this
    at trusted data directories.
    """

    def __init__(self, benign_dir="dataset/benign", malware_dir="dataset/malware"):
        self.benign_dir = benign_dir
        self.malware_dir = malware_dir
        # Sorted so that the index -> file mapping is deterministic across runs.
        self.benign_files = sorted(os.listdir(benign_dir))
        self.malware_files = sorted(os.listdir(malware_dir))

    def __getitem__(self, index):
        num_benign = len(self.benign_files)
        total = len(self)
        # Normalise negative indices against the *whole* dataset; the previous
        # try/except version resolved e.g. -1 to the last benign file.
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(f"index out of range for dataset of size {total}")
        if index < num_benign:
            file_dir = os.path.join(self.benign_dir, self.benign_files[index])
            label = 0.0
        else:
            file_dir = os.path.join(
                self.malware_dir, self.malware_files[index - num_benign]
            )
            label = 1.0
        with open(file_dir, "rb") as f:
            file_ = torch.tensor(pickle.load(f))
        return file_, label

    def __len__(self):
        return len(self.benign_files) + len(self.malware_files)


class UniLabelDataset(Dataset):
    """Dataset over one directory of pickled sequences that all share a label.

    Files in ``data_dir`` are sorted by name; sample ``i`` is
    ``(torch.tensor(<unpickled file i>), float(is_malware))``.

    NOTE(security): ``pickle.load`` can execute arbitrary code; only use
    trusted directories.
    """

    def __init__(self, data_dir, is_malware):
        self.data_dir = data_dir
        self.is_malware = is_malware
        # Deterministic ordering of samples.
        self.files = sorted(os.listdir(data_dir))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        path = os.path.join(self.data_dir, self.files[index])
        with open(path, "rb") as handle:
            sample = torch.tensor(pickle.load(handle))
        target = float(self.is_malware)
        return sample, target




def pad_sequence(sequences, max_len=None, padding_value=0):
    """Stack 1-D tensors into one ``(len(sequences), max_len)`` tensor.

    Shorter sequences are right-padded with ``padding_value``; longer ones are
    truncated to ``max_len``. When ``max_len`` is None, the longest sequence's
    length is used. The result has the dtype/device of ``sequences[0]``.
    """
    rows = len(sequences)
    width = max(seq.size(0) for seq in sequences) if max_len is None else max_len
    padded = sequences[0].new_full((rows, width), padding_value)
    for row, seq in enumerate(sequences):
        # Copy at most `width` leading elements; the rest stays as padding.
        keep = min(seq.size(0), width)
        padded[row, :keep] = seq[:keep]
    return padded

def collate_fn(batch):
    """DataLoader collate function for ``(sequence, label)`` samples.

    Every sequence is padded/truncated to exactly 4096 entries using padding
    value 256 (presumably one past the 0-255 byte range so padding is
    distinguishable — confirm against the embedding size). Labels are stacked
    into a 1-D tensor.
    """
    sequences = [sample[0] for sample in batch]
    labels = [sample[1] for sample in batch]
    return (
        pad_sequence(sequences, max_len=4096, padding_value=256),
        torch.tensor(labels),
    )


def train_val_test_split(idx, val_size, test_size, random_state=None):
    """Shuffle-split ``idx`` into ``(train_idx, val_idx, test_idx)``.

    ``test_size`` is applied to all of ``idx``; ``val_size`` is then applied to
    the *remaining* train+val part, not to the full set. For example, 100 items
    with ``val_size=0.2, test_size=0.2`` give 64 / 16 / 20.

    Args:
        idx: indexable sequence to split (e.g. a ``range``).
        val_size, test_size: fraction (float) or absolute count (int), as
            accepted by ``sklearn.model_selection.train_test_split``.
        random_state: optional seed or RandomState forwarded to both splits so
            the split can be reproduced independently of global RNG state.
            ``None`` (default) keeps the previous behaviour.
    """
    tv_idx, test_idx = train_test_split(
        idx, test_size=test_size, shuffle=True, random_state=random_state
    )
    train_idx, val_idx = train_test_split(
        tv_idx, test_size=val_size, shuffle=True, random_state=random_state
    )
    return train_idx, val_idx, test_idx
    
def make_idx(dataset, val_size, test_size):
    """Build train/val/test index lists for a two-directory ``MalwareDataset``.

    Benign indices (``0 .. num_benign - 1``) and malware indices
    (``num_benign .. num_benign + num_malware - 1``) are split separately with
    the same fractions, so each split keeps roughly the overall class ratio.

    Returns:
        ``(train_idx, val_idx, test_idx)``: lists of dataset indices, benign
        indices first, then malware.
    """
    num_benign = len(dataset.benign_files)
    num_malware = len(dataset.malware_files)
    benign_splits = train_val_test_split(range(num_benign), val_size, test_size)
    malware_splits = train_val_test_split(
        range(num_benign, num_benign + num_malware), val_size, test_size
    )
    # Explicit list() so the two parts are concatenated even if the splitter
    # ever returns arrays (where `+` would add element-wise instead).
    train_idx, val_idx, test_idx = (
        list(benign_part) + list(malware_part)
        for benign_part, malware_part in zip(benign_splits, malware_splits)
    )
    return train_idx, val_idx, test_idx



def make_loaders(batch_size, val_size, test_size):
    """Create ``(train_loader, val_loader, test_loader)`` over the default ``MalwareDataset``.

    Indices come from ``make_idx`` (per-class split); each split is wrapped in
    a ``Subset`` and turned into a loader by ``make_loader``.
    """
    full_dataset = MalwareDataset()
    split_indices = make_idx(full_dataset, val_size, test_size)
    return tuple(
        make_loader(Subset(full_dataset, indices=indices), batch_size)
        for indices in split_indices
    )
    
def make_loader(dataset, batch_size, shuffle=True):
    """Wrap ``dataset`` in a ``DataLoader`` that batches with ``collate_fn``.

    Args:
        dataset: map-style dataset yielding ``(sequence, label)`` pairs.
        batch_size: samples per batch.
        shuffle: reshuffle every epoch. Defaults to True (previous behaviour);
            pass False for validation/test loaders to get a deterministic order.
    """
    return DataLoader(
        dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle
    )



In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import auc, confusion_matrix, roc_curve
from torch import optim
from torch.nn.functional import sigmoid
from tqdm.auto import tqdm


def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Note: setting ``PYTHONHASHSEED`` here does not change hash randomisation of
    the already-running interpreter; it only affects child interpreters
    started afterwards.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA
    torch.backends.cudnn.deterministic = True
    # benchmark=True makes cuDNN pick algorithms by timing, which can differ
    # between runs and defeats the deterministic setting above.
    torch.backends.cudnn.benchmark = False


def count_params(model, trainable_only=True):
    """Return the total number of scalar parameters in ``model``.

    With ``trainable_only`` (default) only parameters with
    ``requires_grad=True`` are counted.
    """
    params = model.parameters()
    if trainable_only:
        params = (param for param in params if param.requires_grad)
    return sum(param.numel() for param in params)


def set_plt_style():
    """Switch matplotlib's global rcParams to LaTeX-rendered serif text.

    ``text.usetex=True`` makes matplotlib typeset text through an external
    LaTeX installation, so one must be available when figures are drawn.
    """
    latex_serif = {
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["cm"],
    }
    plt.rcParams.update(latex_serif)


def plot_confusion_matrix(model, test_loader, save_title, device, normalize="all"):
    """Plot the benign/malware confusion matrix and save it as ``imgs/<save_title>_conf_mat.png``.

    Rows are true labels and columns are predicted labels (sklearn convention);
    class 0 is shown as "Benign" and class 1 as "Malware".

    Args:
        model, test_loader, device: forwarded to ``predict``.
        save_title: file-name prefix.
        normalize: forwarded to ``sklearn.metrics.confusion_matrix``
            (``"true"``, ``"pred"``, ``"all"`` or None).
    """
    y_true, y_pred = predict(model, test_loader, device)
    # labels=[0, 1] keeps the matrix 2x2 even if one class is absent from the
    # test set, so it always matches the two axis labels below.
    # NOTE(review): assumes predict() returns 0/1 (or 0.0/1.0) labels — confirm.
    conf_mat = confusion_matrix(y_true, y_pred, labels=[0, 1], normalize=normalize)
    axis_labels = ("Benign", "Malware")
    df = pd.DataFrame(conf_mat, index=axis_labels, columns=axis_labels)
    # Draw on an explicit new figure; sns.heatmap without `ax` uses the current
    # axes and could draw over a figure left open by another cell.
    fig, ax = plt.subplots()
    sns.heatmap(df, annot=True, cmap="Blues", ax=ax)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    os.makedirs("imgs", exist_ok=True)
    fig.savefig(os.path.join("imgs", f"{save_title}_conf_mat.png"), dpi=300)
    plt.close(fig)


def plot_roc_curve(models, test_loader, save_title, device):
    """Plot one or more ROC curves and save them as ``imgs/<save_title>_roc.png``.

    Args:
        models: a single model (legend label ``save_title``) or a dict mapping
            legend label -> model, giving one curve per entry.
        test_loader, device: forwarded to ``_rates_auc``.
        save_title: file-name prefix and single-model legend label.

    Each legend entry shows the curve's AUC to two decimals; a dashed diagonal
    marks chance level.
    """
    fig, ax = plt.subplots()
    ax.grid(linestyle="--")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    # Normalise the single-model case to a one-entry mapping so both cases
    # share one code path (previously the single-model branch used an
    # undefined name `fp` and raised NameError).
    labelled_models = models if isinstance(models, dict) else {save_title: models}
    for label, model in labelled_models.items():
        fpr, tpr, auc_score = _rates_auc(model, test_loader, device)
        ax.plot(fpr, tpr, label=f"{label} ({auc_score:.2f})")
    ax.plot([0, 1], [0, 1], linestyle="--", label="Chance (0.5)")
    ax.legend(loc="best")
    os.makedirs("imgs", exist_ok=True)
    fig.savefig(os.path.join("imgs", f"{save_title}_roc.png"), dpi=300)
    plt.close(fig)


def _rates_auc(model, test_loader, device):
    """Return ``(fpr, tpr, auc_score)`` for ``model`` on ``test_loader``.

    Scores come from ``predict(..., apply_sigmoid=True)``; ``fpr``/``tpr`` are
    the arrays from ``sklearn.metrics.roc_curve`` and ``auc_score`` is the
    trapezoidal area under that curve.
    """
    y_true, y_pred = predict(model, test_loader, device, apply_sigmoid=True)
    # `roc_curve` is imported directly from sklearn.metrics; the bare name
    # `metrics` was never imported and raised NameError.
    fpr, tpr, _ = roc_curve(y_true, y_pred)
    auc_score = auc(fpr, tpr)
    return fpr, tpr, auc_score


def visualize_data(input, title=None):
    """Show a channel-first image tensor with matplotlib after undoing normalisation.

    The tensor is transposed from (C, H, W) to (H, W, C), de-normalised with
    the widely used ImageNet per-channel mean/std, clipped to [0, 1] and drawn
    with ``plt.imshow``. ``title`` (optional) is passed to ``plt.title``.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    image = input.numpy().transpose((1, 2, 0))
    image = np.clip(image * channel_std + channel_mean, 0, 1)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    # Short pause so the figure refreshes when drawn in an interactive loop.
    plt.pause(0.001)

# Preview one training batch as an image grid.
# NOTE(review): `loaders_data` and `class_names` are not defined in any cell
# shown in this notebook, and make_grid / visualize_data expect image tensors
# (channel-first) rather than the padded (batch, 4096) sequences produced by
# collate_fn — TODO define these names in an earlier cell or drop this preview.
import torchvision  # otherwise only imported in a later cell -> NameError on Run All

inputs_data, classes = next(iter(loaders_data['train_data']))

out = torchvision.utils.make_grid(inputs_data)

visualize_data(out, title=[class_names[x] for x in classes])


def get_accuracy(model, data_loader, device):
    """Return the percentage (0-100) of samples whose prediction equals the label.

    Uses ``predict(..., to_numpy=False)``, so ``y_true``/``y_pred`` are
    compared as tensors and averaged in float64.
    """
    y_true, y_pred = predict(model, data_loader, device, to_numpy=False)
    hits = (y_true == y_pred).to(float)
    return hits.mean().item() * 100


def plot_train_history(train_loss_history, val_loss_history, save_title, save_dir="figures"):
    """Plot per-epoch train/validation loss and save ``<save_dir>/<save_title>_train_history.png``.

    Args:
        train_loss_history: per-epoch training losses; also defines the x-axis.
        val_loss_history: per-epoch validation losses (same length as
            ``train_loss_history``).
        save_title: file-name prefix.
        save_dir: output directory, created if missing. Defaults to
            ``"figures"`` (the previous hard-coded location).
    """
    fig, ax = plt.subplots()
    epochs = range(len(train_loss_history))
    ax.set_xlabel("Epochs")
    ax.set_ylabel("BCE Loss")
    ax.grid(linestyle="--")
    ax.plot(epochs, train_loss_history, color="blue", label="train loss")
    ax.plot(epochs, val_loss_history, color="red", label="val loss")
    ax.legend(loc="best")
    # savefig raises FileNotFoundError if the directory does not exist.
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, f"{save_title}_train_history.png"), dpi=300)
    plt.close(fig)


class EarlyStopMonitor:
    """Signal early stopping after ``patience`` consecutive non-improving steps.

    Each ``step(metric)`` is compared with the *best* value seen so far
    (minimum for ``mode="min"``, maximum for ``mode="max"``). A step that does
    not strictly improve on it, ties included, increments ``count``; an improving
    step resets it. ``step`` returns True once ``count > patience``.
    The previous version compared only with the immediately preceding value, so
    an oscillating metric reset the counter and never stopped. It also treated
    ties differently in the two modes.

    Attributes:
        log: every metric passed to ``step``, in order.
        mode: ``"min"`` or ``"max"``.
        count: current number of consecutive non-improving steps.
        patience: number of non-improving steps tolerated.
    """

    def __init__(self, patience, mode="min"):
        assert mode in {"min", "max"}, "`mode` must be one of 'min' or 'max'"
        self.log = []
        self.mode = mode
        self.count = 0
        self.patience = patience

    def step(self, metric):
        """Record ``metric`` and return True if training should stop."""
        if not self.log:
            # First observation: nothing to compare against yet.
            self.log.append(metric)
            return False
        if self.mode == "min":
            improved = metric < min(self.log)
        else:
            improved = metric > max(self.log)
        self.count = 0 if improved else self.count + 1
        self.log.append(metric)
        return self.count > self.patience


In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

In [None]:
def model_training(res_model, criterion, optimizer, scheduler, number_epochs=25,
                   loaders=None, dataset_sizes=None, device=None):
    """Train ``res_model`` and return it loaded with its best validation weights.

    Each epoch runs a ``'train_data'`` phase (forward, backward, optimizer
    step, then one ``scheduler.step()``), followed by a ``'validation_data'``
    phase (forward only). Predictions are the argmax over dim 1 of the output.

    Args:
        res_model: model to train (updated in place and also returned).
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: torch optimizer over ``res_model``'s parameters.
        scheduler: LR scheduler, stepped once per training phase.
        number_epochs: number of epochs to run.
        loaders: dict mapping phase name -> iterable of ``(inputs, labels)``.
            Defaults to the notebook global ``loaders_data``.
        dataset_sizes: dict mapping phase name -> number of samples, used to
            average loss and accuracy. Defaults to the global ``sizes_datasets``.
        device: device inputs/labels are moved to. Defaults to the global
            ``device``.

    Returns:
        ``res_model`` holding the weights from the epoch with the highest
        validation accuracy (the initial weights if accuracy never exceeds 0).
    """
    train_phase, val_phase = "train_data", "validation_data"
    # Fall back to the notebook globals used previously.
    if loaders is None:
        loaders = loaders_data
    if dataset_sizes is None:
        dataset_sizes = sizes_datasets
    if device is None:
        device = globals()["device"]

    since = time.time()

    best_resmodel_wts = copy.deepcopy(res_model.state_dict())
    best_accuracy = 0.0

    for epoch in range(number_epochs):
        print('Epoch {}/{}'.format(epoch, number_epochs - 1))
        print('-' * 10)

        for phase in (train_phase, val_phase):
            # Compare against the real phase names; the old checks used
            # 'train'/'val', so backprop, scheduler and checkpointing never ran.
            is_train = phase == train_phase
            res_model.train(is_train)  # train(False) is equivalent to eval()
            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = res_model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                # .item() keeps the accumulator a plain int (also fine if a
                # loader yields no batches).
                running_corrects += torch.sum(preds == labels).item()

            if is_train:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if not is_train and epoch_acc > best_accuracy:
                best_accuracy = epoch_acc
                best_resmodel_wts = copy.deepcopy(res_model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_accuracy))

    res_model.load_state_dict(best_resmodel_wts)
    return res_model

In [None]:
def run_epoch(model,iterator,criterion):
    
    epoch_loss = 0.0
    epoch_acc = 0.0
    
    model.eval()
    
    with torch.no_grad():
        for batch in iterator:
            text,text_lengths = batch.url
            
            predictions = model(text,text_lengths).squeeze()
              
            loss = criterion(predictions, batch.type)
            acc = binary_accuracy(predictions, batch.type)
            
            epoch_loss += loss.item()
            epoch_acc += acc.item()