In [12]:
import logging
from pathlib import Path

import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import AutoTokenizer, LongformerTokenizerFast

In [13]:
# Filesystem layout. Everything is derived from the kernel's working directory,
# so this assumes the notebook is started from the project's notebooks folder.
NOTEBOOKS_DIR = Path('.').resolve()
PROJECT_DIR = NOTEBOOKS_DIR.parent
DATA_DIR = PROJECT_DIR.joinpath('data')
DATASET_DIR = DATA_DIR.joinpath('pe-machine-learning-dataset')
REPORTS_DIR = DATASET_DIR.joinpath('reports')

# Single seed reused by every split and shuffle below.
RANDOM_STATE = 741

In [14]:
import sys

# Put the project root on the import path so the ``src`` package can be
# imported. The membership check keeps the cell idempotent: re-running it no
# longer appends a duplicate entry to ``sys.path`` each time.
_project_root = str(PROJECT_DIR)
if _project_root not in sys.path:
    sys.path.append(_project_root)

In [15]:
from src import parsers
from src.extractor import VirusTotalFeatureExtractor

In [16]:
# Load the dataset table that is split and chunked in the cells below.
dataset_path = DATA_DIR / 'dataset_with_reports.parquet'
df = pd.read_parquet(dataset_path)

In [17]:
# Sanity check: (rows, columns) of the freshly loaded table.
df.shape

(20362, 12)

In [21]:
# Class balance of the Blacklist / Whitelist label used for stratification.
# Bracket access avoids relying on attribute-style lookup of a column called ``list``.
df['list'].value_counts()

list
Blacklist    10785
Whitelist     9577
Name: count, dtype: int64

In [7]:
# Options shared by both splits: same seed, shuffled before splitting.
split_options = {'random_state': RANDOM_STATE, 'shuffle': True}

# Hold out 20% of the rows as the test set, preserving the label ratio.
df_train, df_test = train_test_split(
    df, test_size=0.2, stratify=df['list'], **split_options
)

# Carve 10% of the remaining rows off as the validation set, again stratified.
df_train, df_valid = train_test_split(
    df_train, test_size=0.1, stratify=df_train['list'], **split_options
)

In [8]:
# Tokenizer used to measure text length when cutting reports into chunks.
# Alternative checkpoint (currently disabled): 'cointegrated/rubert-tiny2'.
TOKENIZER_CHECKPOINT = 'kazzand/ru-longformer-tiny-16384'
tokenizer = LongformerTokenizerFast.from_pretrained(TOKENIZER_CHECKPOINT)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
# Chunk window and overlap handed to the splitter; the overlap is half the
# window. Lengths are measured with ``tokenizer`` (token counts, per the
# ``from_huggingface_tokenizer`` constructor).
CHUNK_SIZE = 10_240
CHUNK_OVERLAP = 5_120

text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    tokenizer,
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)

In [10]:
_LOGGER = logging.getLogger(__name__)

# (parser class name in ``src.parsers``, attribute read from the extractor),
# listed in the order their texts are appended to the result.
_SECTION_PARSERS: tuple[tuple[str, str], ...] = (
    ('MagicParser', 'magic'),
    ('TypeTagParser', 'type_tag'),
    ('TypeTagsParser', 'type_tags'),
    ('DetectitEasyParser', 'detectiteasy'),
    ('TypeExtensionParser', 'type_extension'),
    ('ImportListParser', 'import_list'),
    ('MitreAttackTechniquesParser', 'mitre_attack_techniques'),
    ('SignatureMatchesParser', 'signature_matches'),
    ('CommandExecutionsParser', 'command_executions'),
    ('ProcessesTreeParser', 'processes_tree'),
    ('ProcessesInjectedParser', 'processes_injected'),
    ('ProcessesCreatedParser', 'processes_created'),
    ('ProcessesTerminatedParser', 'processes_terminated'),
    ('FilesOpenedParser', 'files_opened'),
    ('FilesCopiedParser', 'files_copied'),
    ('FilesDroppedParser', 'files_dropped'),
    ('FilesWrittenParser', 'files_written'),
    ('FilesAttributeChangedParser', 'files_attribute_changed'),
    ('MutexesOpenedParser', 'mutexes_opened'),
    ('MutexesCreatedParser', 'mutexes_created'),
    ('ModulesLoadedParser', 'modules_loaded'),
    ('RegistryKeysOpenedParser', 'registry_keys_opened'),
    ('RegistryKeysSetParser', 'registry_keys_set'),
    ('RegistryKeysDeletedParser', 'registry_keys_deleted'),
    ('IpTrafficParser', 'ip_traffic'),
    ('DNSLookupsParser', 'dns_lookups'),
    ('ServicesStartedParser', 'services_started'),
    ('ServicesOpenedParser', 'services_opened'),
    ('CallsHighlightedParser', 'calls_highlighted'),
    ('HTTPConversationsParser', 'http_conversations'),
    ('SignalsHookedParser', 'signals_hooked'),
    ('WindowsSearchedParser', 'windows_searched'),
)


def extract_texts(extractor) -> list[str]:
    """Collect the text produced by every section parser for one report.

    Each entry of ``_SECTION_PARSERS`` names a parser class in ``parsers`` and
    the extractor attribute it is fed. The parser's ``transform`` output is
    appended to the result in table order.

    Extraction is best-effort: a section whose attribute lookup or parsing
    raises an ``Exception`` is skipped (and logged at DEBUG level) so one
    missing or malformed section does not discard the whole report.
    ``KeyboardInterrupt`` and ``SystemExit`` are deliberately *not* caught, so
    a long-running loop around this function can still be interrupted.

    Args:
        extractor: Object exposing the report sections as attributes
            (e.g. ``magic``, ``import_list``, ``dns_lookups``).

    Returns:
        The texts of all sections that parsed successfully; empty if none did.
    """
    texts = []
    for parser_name, attribute in _SECTION_PARSERS:
        try:
            # Parser lookup, attribute access and parsing all stay inside the
            # ``try`` so any of them failing just skips this one section.
            parser = getattr(parsers, parser_name)()
            texts.extend(parser.transform(getattr(extractor, attribute)))
        except Exception:
            _LOGGER.debug(
                'Skipping section %r (%s)', attribute, parser_name, exc_info=True
            )
    return texts

In [11]:
def build_chunk_frame(split_df: pd.DataFrame) -> pd.DataFrame:
    """Expand one dataset split into a shuffled table of labelled text chunks.

    For every row, the ``<sha256>.json`` report under ``REPORTS_DIR`` is loaded,
    the texts returned by ``extract_texts`` are joined with newlines, and the
    result is cut into overlapping chunks by ``text_splitter``. Each chunk
    becomes one output row carrying the source row's ``id``, ``sha256`` and
    label.

    Args:
        split_df: Rows to process; must provide the ``id``, ``sha256`` and
            ``list`` columns.

    Returns:
        DataFrame with columns ``FILENAME``, ``HASH``, ``TEXT``, ``LABEL``
        (``'malware'`` for ``Blacklist`` rows, otherwise ``'benign'``) and
        ``LABEL_ID`` (1 / 0), shuffled deterministically with ``RANDOM_STATE``.
        A row whose report yields no chunks contributes no output rows.
    """
    columns = {
        'FILENAME': [],
        'HASH': [],
        'TEXT': [],
        'LABEL': [],
        'LABEL_ID': [],
    }

    # itertuples yields namedtuples, so ``row.list`` is a plain field lookup.
    # On a Series from iterrows() the same attribute collides with the name of
    # the ``Series.list`` accessor in recent pandas releases.
    for row in tqdm(split_df.itertuples(index=False), total=len(split_df)):
        report_path = REPORTS_DIR / f'{row.sha256}.json'

        extractor = VirusTotalFeatureExtractor.from_json(report_path)
        corpus = '\n'.join(extract_texts(extractor))
        chunks = text_splitter.split_text(corpus)

        is_malware = row.list == 'Blacklist'
        label = 'malware' if is_malware else 'benign'
        label_id = 1 if is_malware else 0

        n_chunks = len(chunks)
        columns['FILENAME'].extend([row.id] * n_chunks)
        columns['HASH'].extend([row.sha256] * n_chunks)
        columns['TEXT'].extend(chunks)
        columns['LABEL'].extend([label] * n_chunks)
        columns['LABEL_ID'].extend([label_id] * n_chunks)

    # One seeded shuffle is enough: repeating ``sample`` with the same seed only
    # re-applies the same permutation and adds no randomness. Note this yields a
    # different (still deterministic) row order than the former ten-pass loop.
    return pd.DataFrame(columns).sample(
        frac=1, random_state=RANDOM_STATE, ignore_index=True
    )


# Build and persist the chunk table of every split. Splitting happened per
# source row earlier, so all chunks of one report stay within a single split.
for split_df, split_filename in (
    (df_train, 'df_train_chunks.parquet'),
    (df_valid, 'df_valid_chunks.parquet'),
    (df_test, 'df_test_chunks.parquet'),
):
    build_chunk_frame(split_df).to_parquet(DATA_DIR / split_filename)


  2%|██▎                                                                                                          | 307/14660 [00:17<11:59, 19.96it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (19041 > 16384). Running this sequence through the model will result in indexing errors
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 14660/14660 [15:21<00:00, 15.92it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1629/1629 [01:36<00:00, 16.83it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4073/4073 [04:11<00:00, 16.17it/s]
