### Imports

In [None]:
import torch
import torch.nn.functional as F
import torchtext
import time
import random
import pandas as pd
import json
from tqdm import tqdm
import os
import re

from torchtext.data.utils import get_tokenizer
from torch.utils.data import DataLoader

from sqlalchemy.orm import Session
from sqlalchemy import create_engine, select, MetaData, Table
from sqlalchemy.ext.automap import automap_base

from utils.batch import collate_batch, BatchSamplerSimilarLength
from utils.train import train_model, compute_accuracy
from utils.plot import plot_accuracy, plot_training_loss

from models.lstm import LSTM

### Settings and Hyperparameters

In [None]:
# Load project settings; the database connection string is read from the
# 'sqlalchemy_database_uri' key.
with open('../settings.json', encoding='utf-8') as f:
    settings = json.load(f)

db_uri = settings['sqlalchemy_database_uri']

# Seed both RNGs this notebook relies on: torch's global generator and
# Python's `random` module (random.shuffle is used when the report cache is
# first built), so a fresh run produces the same ordering.
RANDOM_SEED = 123
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# Training hyperparameters.
VOCABULARY_SIZE = 20000  # upper bound on vocabulary entries (passed as max_tokens)
LEARNING_RATE = 0.005
BATCH_SIZE = 8
NUM_EPOCHS = 15
# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Model dimensions.
EMBEDDING_DIM = 256
HIDDEN_DIM = 512
NUM_CLASSES = 3  # labels are encoded as dfs -> 0, bfs -> 1, anything else -> 2

### Define Tokenizer

In [None]:
def custom_tokenizer(line):
    """Split a report line into lowercase tokens.

    Commas and backslashes are treated as separators (replaced by spaces)
    before splitting on whitespace, so a Windows path such as
    ``C:\\Users\\foo`` becomes ``['c:', 'users', 'foo']``.

    Args:
        line: The text to tokenize.

    Returns:
        A list of lowercase tokens; empty if `line` contains only separators.
    """
    # A single replace of '\\' already removes every backslash, so no extra
    # pass for doubled backslashes is needed.
    return line.lower().replace(',', ' ').replace('\\', ' ').split()

# Route the custom tokenizer through torchtext's get_tokenizer so the rest of
# the notebook uses a single `tokenizer` callable.
tokenizer = get_tokenizer(custom_tokenizer)

### Load Data

In [None]:
def _report_paths_for_tag(session, analysis, sample_tag, tag, tag_value):
    """Return the `report` column values selected for the given tag value.

    Args:
        session: An open SQLAlchemy session used to execute the query.
        analysis: Reflected 'analysis' table.
        sample_tag: Reflected 'sample_tag' table.
        tag: Reflected 'tag' table.
        tag_value: Value matched against `tag.value`.

    Returns:
        A list with the first column (the report path) of every result row.
    """
    # NOTE(review): the select starts from `analysis` and then joins `analysis`
    # again, while `tag` only appears in the ON/WHERE clauses -- confirm the
    # generated SQL is what is intended.
    stmt = select(analysis.c.report).join(sample_tag, tag.c.id == sample_tag.c.tag_id).join(
        analysis, sample_tag.c.sample_id == analysis.c.sample).where(
        tag.c.value == tag_value
    )
    return [row[0] for row in session.execute(stmt).fetchall()]


def _read_reports(report_paths, desc):
    """Read every file in `report_paths` and return their contents as strings.

    `desc` is the label shown on the tqdm progress bar.
    """
    contents = []
    for report_path in tqdm(report_paths, desc=desc):
        with open(report_path) as f:
            contents.append(f.read())
    return contents


def _tokenize_dynamic_report(raw_report):
    """Tokenize the 'dynamic' section of a JSON report string.

    Each entry contributes its 'Operation', 'Path' and 'Result' fields, joined
    into one line and passed through `tokenizer`; all tokens are concatenated
    into a single flat list.
    """
    tokens = []
    for item in json.loads(raw_report)['dynamic']:
        line = f"{item['Operation']}, {item['Path']}, {item['Result']}"
        tokens.extend(tokenizer(line))
    return tokens


# Build the tokenized dataset from the database only when no cached copy
# exists; otherwise load the cache.
if not os.path.isfile('reports.json'):
    # connect to database
    engine = create_engine(db_uri)

    # load tables (reflected from the live schema)
    metadata_obj = MetaData()
    Tag = Table('tag', metadata_obj, autoload_with=engine)
    SampleTag = Table('sample_tag', metadata_obj, autoload_with=engine)
    Analysis = Table('analysis', metadata_obj, autoload_with=engine)

    # The context manager closes the session even if one of the queries raises.
    with Session(engine) as session:
        dfs_report_paths = _report_paths_for_tag(session, Analysis, SampleTag, Tag, 'file_search_dfs')
        bfs_report_paths = _report_paths_for_tag(session, Analysis, SampleTag, Tag, 'file_search_bfs')
        benign_report_paths = _report_paths_for_tag(session, Analysis, SampleTag, Tag, 'benign')

    # fetch reports
    bfs_reports = _read_reports(bfs_report_paths, "Reading BFS reports")
    dfs_reports = _read_reports(dfs_report_paths, "Reading DFS reports")
    benign_reports = _read_reports(benign_report_paths, "Reading benign reports")

    # combine reports as mutable [content, label] pairs
    reports = [[r, 'bfs'] for r in bfs_reports] + [[r, 'dfs'] for r in dfs_reports] + [[r, 'benign'] for r in benign_reports]

    # shuffle reports (uses Python's global `random` state)
    random.shuffle(reports)

    # Tokenize reports: replace each raw JSON string with its token list in place.
    for report in tqdm(reports, desc="Tokenizing reports"):
        report[0] = _tokenize_dynamic_report(report[0])

    # json dump reports to file
    print("Dumping reports to file")
    with open('reports.json', 'w') as f:
        json.dump(reports, f)

else:
    print("Loading reports from file")
    # load reports from file
    with open('reports.json') as f:
        reports = json.load(f)

### Split Data

In [None]:
# Split into training, validation, and test sets
from torchdata.datapipes.iter import IterableWrapper
dp = IterableWrapper(reports)

# `reports` is an in-memory list, so its length is the row count; there is no
# need to iterate the whole datapipe into a throw-away list just to count it.
N_ROWS = len(reports)
N_train = int(N_ROWS * 0.8)
N_valid = int(N_ROWS * 0.1)
# The test split takes whatever remains so the three sizes always sum to N_ROWS.
N_test = N_ROWS - N_train - N_valid

# Split into training, validation, and test datapipes early on.
# Will build vocabulary from training datapipe only.
train_dp, valid_dp, test_dp = dp.random_split(
    total_length=N_ROWS,
    weights={"train": N_train, "valid": N_valid, "test": N_test},
    seed=RANDOM_SEED,
)

print(f'Num Train: {len(train_dp)}')
print(f'Num Validate: {len(valid_dp)}')
print(f'Num Test: {len(test_dp)}')

### Build Vocabulary

In [None]:
# build vocab
from torchtext.vocab import build_vocab_from_iterator


def yield_tokens(data_iter):
    """Yield the token list from each (tokens, label) pair in `data_iter`."""
    for text, _ in data_iter:
        yield text


# The vocabulary is built from the training split only.
vocab = build_vocab_from_iterator(yield_tokens(train_dp), specials=["<unk>", "<pad>"], max_tokens=VOCABULARY_SIZE)
# Tokens missing from the vocabulary map to the <unk> index.
vocab.set_default_index(vocab["<unk>"])
# The padding special is registered as lowercase "<pad>". Looking up "<PAD>"
# would miss and silently fall back to the default (<unk>) index, making
# padding indistinguishable from unknown tokens.
PADDING_VALUE = vocab['<pad>']

### Define text and label transforms

In [None]:
def text_transform(x):
    """Map every token in `x` to its index in `vocab`."""
    return [vocab[token] for token in x]


def label_transform(x):
    """Encode a label as a class index: 'dfs' -> 0, 'bfs' -> 1, anything else -> 2."""
    if x == 'dfs':
        return 0
    if x == 'bfs':
        return 1
    return 2


# Print out the output of text_transform
print("input to the text_transform:", "here is an example")
first_train_tokens = list(train_dp)[0][0]  # token list of the first training sample
print("output of the text_transform:", text_transform(first_train_tokens))

In [None]:

def collate_batch_wrapper(batch):
    """Collate `batch` with `collate_batch`, binding the notebook-level settings.

    DataLoader calls `collate_fn` with the batch only, so the padding value,
    device and the text/label transforms are supplied here.
    """
    return collate_batch(batch=batch,
                         padding_value=PADDING_VALUE,
                         device=DEVICE,
                         text_transform=text_transform,
                         label_transform=label_transform)


# Materialize each split so it can be indexed by the batch samplers.
train_dp_list = list(train_dp)
valid_dp_list = list(valid_dp)
test_dp_list = list(test_dp)

# Each DataLoader must wrap the same list its batch sampler was built from:
# the sampler yields indices and the loader resolves them against its own
# dataset. Pairing a validation/test sampler with the training list would
# evaluate on training samples.
train_loader = DataLoader(train_dp_list,
                          batch_sampler=BatchSamplerSimilarLength(dataset=train_dp_list, batch_size=BATCH_SIZE),
                          collate_fn=collate_batch_wrapper)
valid_loader = DataLoader(valid_dp_list,
                          batch_sampler=BatchSamplerSimilarLength(dataset=valid_dp_list, batch_size=BATCH_SIZE, shuffle=False),
                          collate_fn=collate_batch_wrapper)
test_loader = DataLoader(test_dp_list,
                         batch_sampler=BatchSamplerSimilarLength(dataset=test_dp_list, batch_size=BATCH_SIZE, shuffle=False),
                         collate_fn=collate_batch_wrapper)

# Sanity check: print the shapes of one training batch.
text_batch, label_batch = next(iter(train_loader))
print(text_batch.size())
print(label_batch.size())

In [None]:
# Show the tensor shapes of the first batch from each loader.
for header, loader in (('Train', train_loader),
                       ('\nValid:', valid_loader),
                       ('\nTest:', test_loader)):
    print(header)
    for text_batch, label_batch in loader:
        print(f'Text matrix size: {text_batch.size()}')
        print(f'Target vector size: {label_batch.size()}')
        # Only the first batch is needed.
        break

### Train Model

In [None]:
# Build the classifier and move it to the selected device.
model = LSTM(
    input_dim=len(vocab),
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    output_dim=NUM_CLASSES,  # could use 1 for binary classification
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# NOTE(review): this scheduler is constructed but never passed to train_model
# nor stepped anywhere in this file -- confirm whether that is intended.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.1,
    verbose=True,
)

# Run training; the helper returns the per-minibatch losses and the
# training/validation accuracy histories that are plotted afterwards.
minibatch_loss_list, train_acc_list, valid_acc_list = train_model(
    model=model,
    optimizer=optimizer,
    device=DEVICE,
    num_epochs=NUM_EPOCHS,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    logging_interval=100,
)

In [None]:
# Plot the minibatch training loss, then the train/validation accuracy curves.
batches_per_epoch = len(train_loader)
plot_training_loss(
    minibatch_loss_list=minibatch_loss_list,
    num_epochs=NUM_EPOCHS,
    iter_per_epoch=batches_per_epoch,
    results_dir=None,
    averaging_iterations=100,
)

plot_accuracy(
    train_acc_list=train_acc_list,
    valid_acc_list=valid_acc_list,
    results_dir=None,
)

In [None]:
# save model, vocab, and optimizer state
# Create the output directory first: torch.save does not create missing
# parent directories, and failing here would discard the trained model.
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), 'model_data/lstm_02.pt')
torch.save(vocab, 'model_data/vocab_02.pt')
torch.save(optimizer.state_dict(), 'model_data/optimizer_02.pt')