In [1]:
from pathlib import Path

import torch
from transformers import (
    ResNetConfig,
    ResNetModel,
    AutoModel,
    PreTrainedModel,
    PretrainedConfig,
    LongformerModel,
    LongformerConfig,
    LongformerTokenizerFast,
    ConvNextImageProcessor,
)

In [2]:
# Project paths, derived from the kernel's working directory.
# NOTE(review): assumes the kernel was started in the notebooks folder, whose
# parent is the project root — confirm if running from elsewhere.
NOTEBOOKS_DIR = Path('.').resolve()
PROJECT_DIR = NOTEBOOKS_DIR.parent
MODEL_DIR = PROJECT_DIR.joinpath('model')

In [3]:
class MalwareDetectionModel(torch.nn.Module):
    """Malware classifier fusing an image branch (ResNet) and a text branch (Longformer).

    The flattened ResNet pooled features and the Longformer pooler output are
    concatenated, passed through a ReLU-activated fusion layer, and projected to
    ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        hidden_size: Width of the fusion layer. The default of 312 matches the
            saved ``pytorch_model.bin`` checkpoint.
        longformer_checkpoint: Hugging Face Hub id (or local path) whose
            pretrained weights initialise the Longformer branch.
    """

    def __init__(
            self,
            num_classes: int = 2,
            hidden_size: int = 312,
            longformer_checkpoint: str = 'kazzand/ru-longformer-tiny-16384',
    ) -> None:
        # Zero-argument super() binds to this class object. The explicit
        # super(MalwareDetectionModel, self) form looks the name up globally at
        # call time, and a later cell in this notebook rebinds that name.
        super().__init__()
        resnet_config = ResNetConfig(
            num_channels=1,  # single-channel input images
            torch_dtype=torch.float32,
            depths=[2, 2, 2, 2],  # basic blocks, [2, 2, 2, 2]: the ResNet-18 layout
            hidden_sizes=[64, 128, 256, 512],
            layer_type='basic',
            model_type='resnet',
        )
        self.resnet_model = ResNetModel(resnet_config)
        self.longformer_model = AutoModel.from_pretrained(longformer_checkpoint)

        # The fusion layer consumes both branches' features side by side.
        in_features = self.resnet_model.config.hidden_sizes[-1] + self.longformer_model.config.hidden_size

        self.linear = torch.nn.Linear(in_features, hidden_size)
        self.fc = torch.nn.Linear(hidden_size, num_classes)

    def forward(
            self,
            longformer_input_ids,
            longformer_attention_mask,
            longformer_global_attention_mask,
            resnet_pixel_values,
    ):
        """Return unnormalised class logits with last dimension ``num_classes``.

        Args:
            longformer_input_ids: Token ids for the Longformer branch.
            longformer_attention_mask: Attention (padding) mask for the token ids.
            longformer_global_attention_mask: Marks tokens that get global attention.
            resnet_pixel_values: Image tensor for the ResNet branch; presumably
                ``(batch, 1, H, W)`` given ``num_channels=1`` — TODO confirm.
        """
        # Flatten everything after the batch dimension of the pooled ResNet output.
        image_features = self.resnet_model(pixel_values=resnet_pixel_values).pooler_output.flatten(start_dim=1)

        text_features = self.longformer_model(
            input_ids=longformer_input_ids,
            attention_mask=longformer_attention_mask,
            global_attention_mask=longformer_global_attention_mask,
        ).pooler_output

        fused = torch.relu(self.linear(torch.concat([image_features, text_features], dim=1)))
        return self.fc(fused)

In [4]:
model = MalwareDetectionModel()
# weights_only=True stops torch.load from unpickling arbitrary objects; the
# checkpoint only needs to hold a state dict of tensors.
model.load_state_dict(torch.load(MODEL_DIR / 'pytorch_model.bin', map_location='cpu', weights_only=True))

Some weights of LongformerModel were not initialized from the model checkpoint at kazzand/ru-longformer-tiny-16384 and are newly initialized: ['longformer.pooler.dense.bias', 'longformer.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [5]:
# Pull the configs of both sub-models so they can be embedded in the composite config.
resnet_config, longformer_config = model.resnet_model.config, model.longformer_model.config

In [6]:
class MalwareDetectionConfig(PretrainedConfig):
    """Configuration for the ResNet + Longformer malware classifier.

    Args:
        resnet_config: ``ResNetConfig.to_dict()`` of the image branch, or ``None``
            until it is filled in.
        longformer_config: ``LongformerConfig.to_dict()`` of the text branch, or
            ``None`` until it is filled in.
        hidden_size: Width of the fusion layer between the concatenated branch
            features and the classification head.
        num_classes: Number of output logits.
        **kwargs: Forwarded to ``PretrainedConfig`` (e.g. ``id2label``, ``label2id``).
    """

    def __init__(
            self,
            resnet_config=None,
            longformer_config=None,
            hidden_size: int = 312,
            num_classes: int = 2,
            **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Sub-configs are kept as plain dicts so the whole config stays serializable.
        self.resnet_config = resnet_config
        self.longformer_config = longformer_config
        self.hidden_size = hidden_size
        self.num_classes = num_classes

In [7]:
# Instantiate the composite config; its model-specific fields are assigned below.
mdconfig = MalwareDetectionConfig()

In [8]:
# Fusion-layer width and label vocabulary of the classifier.
mdconfig.hidden_size = 312
mdconfig.id2label = {0: 'benign', 1: 'malware'}
# Derive the inverse mapping and the class count so they cannot drift from id2label.
mdconfig.label2id = {label: idx for idx, label in mdconfig.id2label.items()}
mdconfig.num_classes = len(mdconfig.id2label)

In [9]:
# Embed both sub-configs as plain dicts; the model class rebuilds them with from_dict.
mdconfig.resnet_config, mdconfig.longformer_config = (
    resnet_config.to_dict(),
    longformer_config.to_dict(),
)

In [10]:
class MalwareDetectionModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` version of the ResNet + Longformer classifier.

    Both branches are rebuilt from the sub-config dicts stored on a
    ``MalwareDetectionConfig``. This lets the model go through
    ``save_pretrained`` / ``from_pretrained``.
    """

    config_class = MalwareDetectionConfig

    def __init__(self, config: MalwareDetectionConfig) -> None:
        # Zero-argument super() binds to this class object instead of looking up the
        # global name MalwareDetectionModel, which this notebook defines twice.
        super().__init__(config)
        # Bind the caller's config object itself (PreTrainedModel.__init__ may store a copy).
        self.config = config

        # config.resnet_config / config.longformer_config hold plain dicts.
        resnet_config = ResNetConfig.from_dict(config.resnet_config)
        self.resnet_model = ResNetModel(resnet_config)

        longformer_config = LongformerConfig.from_dict(config.longformer_config)
        self.longformer_model = LongformerModel(longformer_config)

        # The fusion layer consumes both branches' features side by side.
        in_features = self.resnet_model.config.hidden_sizes[-1] + self.longformer_model.config.hidden_size

        self.linear = torch.nn.Linear(in_features, self.config.hidden_size)
        self.fc = torch.nn.Linear(self.config.hidden_size, self.config.num_classes)
        # post_init() is the documented end-of-__init__ hook for PreTrainedModel:
        # it runs init_weights() plus the library's remaining final set-up.
        self.post_init()

    def forward(
            self,
            longformer_input_ids,
            longformer_attention_mask,
            longformer_global_attention_mask,
            resnet_pixel_values,
    ):
        """Return unnormalised class logits with last dimension ``config.num_classes``.

        Args:
            longformer_input_ids: Token ids for the Longformer branch.
            longformer_attention_mask: Attention (padding) mask for the token ids.
            longformer_global_attention_mask: Marks tokens that get global attention.
            resnet_pixel_values: Image tensor for the ResNet branch; shape is
                presumably ``(batch, num_channels, H, W)`` — TODO confirm.
        """
        # Flatten everything after the batch dimension of the pooled ResNet output.
        image_features = self.resnet_model(pixel_values=resnet_pixel_values).pooler_output.flatten(start_dim=1)

        text_features = self.longformer_model(
            input_ids=longformer_input_ids,
            attention_mask=longformer_attention_mask,
            global_attention_mask=longformer_global_attention_mask,
        ).pooler_output

        fused = torch.relu(self.linear(torch.concat([image_features, text_features], dim=1)))
        return self.fc(fused)

In [11]:
# Build the PreTrainedModel variant from the composite config (weights are random until loaded).
model = MalwareDetectionModel(config=mdconfig)

In [12]:
# weights_only=True stops torch.load from unpickling arbitrary objects; the
# checkpoint only needs to hold a state dict of tensors.
model.load_state_dict(torch.load(MODEL_DIR / 'pytorch_model.bin', map_location='cpu', weights_only=True))

<All keys matched successfully>

In [13]:
# PreTrainedModel.save_pretrained writes the weights *and* model.config to
# config.json. model.config is mdconfig itself, so a separate
# mdconfig.save_pretrained(MODEL_DIR) would only rewrite the same config.json.
model.save_pretrained(MODEL_DIR)

In [15]:
# Save the pretrained Longformer checkpoint's tokenizer into MODEL_DIR next to the model files.
tokenizer = LongformerTokenizerFast.from_pretrained(
    'kazzand/ru-longformer-tiny-16384',
)
tokenizer.save_pretrained(save_directory=MODEL_DIR)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


('/Users/vasilyperekhrest/PycharmProjects/malware-detection-system/model/tokenizer_config.json',
 '/Users/vasilyperekhrest/PycharmProjects/malware-detection-system/model/special_tokens_map.json',
 '/Users/vasilyperekhrest/PycharmProjects/malware-detection-system/model/vocab.txt',
 '/Users/vasilyperekhrest/PycharmProjects/malware-detection-system/model/added_tokens.json',
 '/Users/vasilyperekhrest/PycharmProjects/malware-detection-system/model/tokenizer.json')

In [17]:
# Preprocessing for the single-channel ResNet branch. image_mean / image_std are
# scalars, presumably pixel statistics of the training images (TODO confirm
# provenance). rescale_factor is 1/255.
preprocessor_config = {
    "crop_pct": 0.875,
    "do_normalize": True,
    "do_rescale": True,
    "do_resize": True,
    "feature_extractor_type": "ConvNextFeatureExtractor",
    "image_mean": 0.356,
    "image_processor_type": "ConvNextImageProcessor",
    "image_std": 0.332,
    "resample": 3,
    "rescale_factor": 0.00392156862745098,
    "size": {
        "shortest_edge": 224
    }
}
# Unpack the dict into keyword arguments. Passing it positionally binds the whole
# dict to the first parameter (`do_resize`) and silently leaves every other
# option (size, image_mean, image_std, resample, ...) at ConvNext defaults.
image_processor = ConvNextImageProcessor(**preprocessor_config)
image_processor.save_pretrained(MODEL_DIR)

['/Users/vasilyperekhrest/PycharmProjects/malware-detection-system/model/preprocessor_config.json']