# 4. MalRNN을 이용한 변종 문서형 Stream Data 생성
* 필요한 package 다운로드 확인 및 적용

In [None]:
# !pip install torch
# !pip install numpy

In [None]:
import numpy as np
import torch
from rnn_model import CharRNN

## 1. Stream byte 생성 함수 구축

In [None]:
def generate_byte(model, base_stream, device, len_to_predict=1000, temperature=0.8):
    """Extend ``base_stream`` with bytes sampled from a character-level RNN.

    The model is first primed with every element of ``base_stream`` except the
    last one. Sampling then starts from the last element: each step feeds the
    current byte to the model, samples the next byte from the temperature-scaled
    output distribution, and feeds the sampled byte back in.

    Args:
        model: Callable as ``model(input, hidden) -> (output, hidden)`` and
            exposing ``init_hidden(batch_size)``. Its flattened output is used as
            one score per candidate byte (presumably 256 values — confirm
            against the model's ``output_size``).
        base_stream: Non-empty sequence of integer byte values used as the seed.
        device: torch device the model lives on.
        len_to_predict: Number of new bytes to sample.
        temperature: Sampling temperature; lower values sample more greedily.

    Returns:
        ``(stream, outputs)``, where ``stream`` is a list of ints holding
        ``base_stream`` followed by the sampled bytes, and ``outputs`` is the
        list of raw model outputs, one per sampled byte.

    Raises:
        ValueError: if ``base_stream`` is empty (there is no byte to start from).
    """
    base_stream = list(base_stream)
    if not base_stream:
        raise ValueError("base_stream must contain at least one byte")

    predict = [int(b) for b in base_stream]
    output_result = []

    # Pure inference: without no_grad the hidden state would drag an autograd
    # graph through every primed and sampled step.
    with torch.no_grad():
        hidden_state = model.init_hidden(1).to(device)
        base_input = torch.tensor(base_stream, dtype=torch.long, device=device).unsqueeze(0)

        # Prime on all but the last byte; the last byte is the first input of
        # the sampling loop below.
        for p in range(base_input.size(1) - 1):
            _, hidden_state = model(base_input[:, p], hidden_state)

        model_input = base_input[:, -1]
        for _ in range(len_to_predict):
            output, hidden_state = model(model_input, hidden_state)
            output_result.append(output)

            # softmax(output / T) is proportional to exp(output / T), so the
            # sampling distribution is unchanged. It subtracts the max first, so
            # large scores cannot overflow to inf and break multinomial.
            output_dist = torch.softmax(output.view(-1) / temperature, dim=0)
            predict_stream = torch.multinomial(output_dist, 1)

            predict.append(int(predict_stream.item()))
            # Shape (1,) long tensor on `device`, matching base_input[:, p].
            model_input = predict_stream.view(1)

    return predict, output_result

## 2. 변종 문서형 Byte Stream 생성 과정 구축

In [None]:
with open("../data/4_critical_example.txt", "r") as f:
    # The file holds comma-separated integer byte values. Blank tokens (e.g.
    # from a trailing comma before the final newline) are skipped instead of
    # crashing int().
    critical_base = [int(token) for token in f.read().split(",") if token.strip()]

In [None]:
### Build the byte-stream generation pipeline ###
# Steps:
# 1. Pick the compute device.
# 2. Build the model, load its weights, and move it to the device.
# 3. Generate new bytes with generate_byte.
# 4. Build the variant byte stream from the seed plus the generated bytes.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = CharRNN(
    input_size=256,
    hidden_size=100,
    output_size=256,
    model="gru",
    n_layers=1,
)
# NOTE(review): torch.load unpickles data; only load trusted checkpoints. On
# torch >= 2.6 a fully pickled model object also needs weights_only=False.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
checkpoint = torch.load("../data/malRNN_doc.pt", map_location=device)
if isinstance(checkpoint, dict):
    # A state_dict: load it into the architecture constructed above.
    model.load_state_dict(checkpoint)
else:
    # A fully pickled model object: use it instead of the fresh instance.
    model = checkpoint
model.to(device)
model.eval()  # inference mode for generation

new_byte = generate_byte(model=model, base_stream=critical_base, device=device)
# generate_byte returns the seed stream *followed by* the sampled bytes, so
# new_byte[0] already is the complete variant. Prepending critical_base again
# would duplicate the seed.
generated_byte = bytearray(new_byte[0])

## 3. 변종 Byte Stream 탐지 회피 여부 확인

In [None]:
from MalConv import MalConv
import torch.nn.functional as F

In [None]:
def detect(malconv, stream_data):
    """Return MalConv's class-1 softmax score for ``stream_data``.

    Args:
        malconv: Classifier called on a (1, N) uint8 tensor holding the bytes.
            It returns scores whose last axis is normalised with softmax; index
            ``[0, 1]`` is read as the positive class (assumed "malicious";
            confirm against the training labels).
        stream_data: Bytes-like object or iterable of ints in 0-255.

    Returns:
        The class-1 probability as a NumPy float scalar.
    """
    byte_array = np.frombuffer(bytearray(stream_data), dtype=np.uint8)
    batch = torch.from_numpy(byte_array[np.newaxis, :])
    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        output = malconv(batch)
        probs = F.softmax(output, dim=-1).detach().cpu().numpy()
    return probs[0, 1]

# Build the MalConv detector and load its trained weights.
malconv = MalConv(channels=256, window_size=512, embd_size=8)
# NOTE(review): torch.load unpickles data; only load trusted checkpoint files.
# detect() converts outputs with .numpy(), so keep the weights on CPU even if
# the checkpoint was saved from a GPU.
malconv_weight = torch.load("../data/malconv_doc.pth", map_location="cpu")
malconv.load_state_dict(malconv_weight)
# Inference only: disable training-time behaviour (dropout etc., if MalConv has any).
malconv.eval()

In [None]:
# Class-1 softmax score MalConv assigns to the original seed stream (baseline).
detect(malconv, critical_base)

In [None]:
# Class-1 softmax score for the MalRNN-generated variant. If class 1 is the
# malicious class, a score below the seed's baseline suggests the variant
# moves toward evading detection.
detect(malconv, generated_byte)