In [31]:
from pathlib import Path

import torch
from transformers import (
    ViTConfig,
    ViTModel,
    AutoModel,
    PreTrainedModel,
    PretrainedConfig,
    LongformerModel,
    LongformerConfig,
    LongformerTokenizerFast,
    AutoImageProcessor,
)

In [7]:
# Project layout, resolved from the kernel's working directory.
# NOTE(review): assumes the kernel is started inside the project's notebooks folder — confirm.
NOTEBOOKS_DIR = Path('.').resolve()
PROJECT_DIR = NOTEBOOKS_DIR.parent
DATA_DIR = PROJECT_DIR.joinpath('data')
# Export target for the Hugging Face-format model; reused if it already exists.
MODEL_DIR = PROJECT_DIR.joinpath('model-2')
MODEL_DIR.mkdir(exist_ok=True)

In [10]:

class ViTLongformerModel(torch.nn.Module):
    """ViT image tower + Longformer text tower feeding a small MLP classifier.

    Plain ``torch.nn.Module`` variant used to load the training checkpoint.
    The submodule names ``vit_model``, ``longformer_model`` and ``classifier``
    (and the layer order inside ``classifier``) determine the state_dict keys,
    so they must not change.

    Args:
        vit_name: Hub id whose *config* defines the ViT architecture. Only the
            config is fetched; the ViT weights are freshly initialised here and
            are expected to be overwritten by ``load_state_dict``.
        longformer_name: Hub id of the pretrained Longformer text tower.
        hidden_size: Width of the classifier's hidden layer (312 in the trained checkpoint).
        num_labels: Number of output logits.
        dropout: Dropout probability applied before each Linear layer.
    """

    def __init__(
            self,
            vit_name: str = "WinKawaks/vit-small-patch16-224",
            longformer_name: str = 'kazzand/ru-longformer-tiny-16384',
            hidden_size: int = 312,
            num_labels: int = 2,
            dropout: float = 0.2,
    ) -> None:
        # Zero-arg super() is immune to this class name being rebound by a later cell.
        super().__init__()
        vit_cfg = ViTConfig.from_pretrained(vit_name)
        self.vit_model = ViTModel(vit_cfg, add_pooling_layer=False)
        self.longformer_model = AutoModel.from_pretrained(longformer_name)
        in_features = self.vit_model.config.hidden_size + self.longformer_model.config.hidden_size
        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout),
            torch.nn.Linear(in_features, hidden_size),
            torch.nn.BatchNorm1d(hidden_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_size, num_labels),
        )

    def forward(
            self,
            longformer_input_ids,
            longformer_attention_mask,
            longformer_global_attention_mask,
            vit_pixel_values,
    ):
        """Return raw classification logits for a batch of (text, image) pairs.

        Args:
            longformer_input_ids: Token ids for the Longformer tower.
            longformer_attention_mask: Attention mask for the Longformer tower.
            longformer_global_attention_mask: Longformer global-attention mask.
            vit_pixel_values: Preprocessed images for the ViT tower.

        Returns:
            Un-normalised logits; the last dimension has ``num_labels`` entries.
        """
        # First-token embedding from each tower's last hidden state -> 2-D (batch, hidden).
        vit_embds = self.vit_model(pixel_values=vit_pixel_values).last_hidden_state[:, 0, :]
        longformer_embds = self.longformer_model(
            input_ids=longformer_input_ids,
            attention_mask=longformer_attention_mask,
            global_attention_mask=longformer_global_attention_mask,
        ).last_hidden_state[:, 0, :]

        # Both slices are already 2-D, so they can be concatenated on the feature axis directly.
        fused = torch.cat([vit_embds, longformer_embds], dim=1)
        return self.classifier(fused)

In [11]:
# Rebuild the plain module and restore the trained weights on CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted
# source (torch>=1.13 also offers weights_only=True for plain state_dicts).
model = ViTLongformerModel()
model.load_state_dict(
    torch.load(
        DATA_DIR.joinpath('checkpoints', 'model-3-epoch.pt'),
        map_location='cpu',
    )
)

Some weights of LongformerModel were not initialized from the model checkpoint at kazzand/ru-longformer-tiny-16384 and are newly initialized: ['longformer.pooler.dense.bias', 'longformer.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [14]:
# Take the live sub-configs from the loaded plain model so they can be embedded
# into the exportable composite config.
vit_config, longformer_config = model.vit_model.config, model.longformer_model.config

In [15]:
class ViTLongformerConfig(PretrainedConfig):
    """Configuration for the ViT + Longformer two-tower classifier.

    Declaring the fields explicitly means a default-constructed config already
    carries every attribute the model reads, instead of raising AttributeError.

    Args:
        vit_config: Dict produced by ``ViTConfig.to_dict()`` for the image tower.
            Defaults to an empty dict (i.e. the library's default ViTConfig).
        longformer_config: Dict produced by ``LongformerConfig.to_dict()`` for the
            text tower. Defaults to an empty dict.
        hidden_size: Width of the classifier's hidden layer (312 in the trained checkpoint).
        **kwargs: Forwarded to ``PretrainedConfig`` (e.g. ``id2label``,
            ``label2id``, ``num_labels``); unknown keys become attributes there.
    """

    # Identifies this config type in config.json and allows AutoConfig.register().
    model_type = 'vit-longformer'

    def __init__(self, vit_config=None, longformer_config=None, hidden_size=312, **kwargs):
        super().__init__(**kwargs)
        self.vit_config = {} if vit_config is None else vit_config
        self.longformer_config = {} if longformer_config is None else longformer_config
        self.hidden_size = hidden_size

In [16]:
# Start from an empty composite config; labels, sizes and sub-configs are filled in below.
config: ViTLongformerConfig = ViTLongformerConfig()

In [22]:
# Class index -> name for the two output logits; label2id and num_labels are
# derived from it so the mappings cannot drift apart.
config.id2label = dict(enumerate(('benign', 'malware')))
config.label2id = {label: idx for idx, label in config.id2label.items()}
# Classifier hidden width, matching the Linear(in_features, 312) of the trained plain model.
config.hidden_size = 312
config.num_labels = len(config.id2label)

In [23]:
# Store both sub-configs as plain dicts so they serialise into config.json and can
# be rebuilt with ViTConfig.from_dict / LongformerConfig.from_dict.
config.vit_config, config.longformer_config = vit_config.to_dict(), longformer_config.to_dict()

In [26]:
class ViTLongformerModel(PreTrainedModel):
    """Hugging Face-serialisable ViT + Longformer classifier.

    Both towers are built from the sub-config dicts stored on ``config`` instead
    of being fetched from the Hub. Submodule names (``vit_model``,
    ``longformer_model``, ``classifier``) and the classifier's layer order match
    the plain ``torch.nn.Module`` variant, so the same training checkpoint loads
    into either class.

    Config attributes read:
        vit_config: dict accepted by ``ViTConfig.from_dict``.
        longformer_config: dict accepted by ``LongformerConfig.from_dict``.
        hidden_size: width of the classifier's hidden layer.
        num_labels: number of output logits.
        classifier_dropout: optional dropout probability; 0.2 when absent or None.
    """

    config_class = ViTLongformerConfig

    def __init__(self, config: ViTLongformerConfig) -> None:
        # Zero-arg super() is immune to this class name being rebound by another cell.
        super().__init__(config)
        # PreTrainedModel.__init__ stores the config too; re-assigning guarantees that
        # model.config is the exact object passed in, so the caller's `config` and the
        # one the model saves stay the same object.
        self.config = config

        # Architectures come from the stored dicts; weights are expected to be
        # restored afterwards (load_state_dict / from_pretrained).
        self.vit_model = ViTModel(ViTConfig.from_dict(config.vit_config), add_pooling_layer=False)
        self.longformer_model = LongformerModel(LongformerConfig.from_dict(config.longformer_config))

        dropout = getattr(config, 'classifier_dropout', None)
        if dropout is None:
            dropout = 0.2

        in_features = self.vit_model.config.hidden_size + self.longformer_model.config.hidden_size
        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout),
            torch.nn.Linear(in_features, config.hidden_size),
            torch.nn.BatchNorm1d(config.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(config.hidden_size, config.num_labels),
        )
        # post_init() is the current Transformers hook for custom models: it runs
        # init_weights() plus the remaining post-construction setup.
        self.post_init()

    def forward(
            self,
            longformer_input_ids,
            longformer_attention_mask,
            longformer_global_attention_mask,
            vit_pixel_values,
    ):
        """Return raw classification logits for a batch of (text, image) pairs.

        Args:
            longformer_input_ids: Token ids for the Longformer tower.
            longformer_attention_mask: Attention mask for the Longformer tower.
            longformer_global_attention_mask: Longformer global-attention mask.
            vit_pixel_values: Preprocessed images for the ViT tower.

        Returns:
            Un-normalised logits; the last dimension has ``config.num_labels`` entries.
        """
        # First-token embedding from each tower's last hidden state -> 2-D (batch, hidden).
        vit_embds = self.vit_model(pixel_values=vit_pixel_values).last_hidden_state[:, 0, :]
        longformer_embds = self.longformer_model(
            input_ids=longformer_input_ids,
            attention_mask=longformer_attention_mask,
            global_attention_mask=longformer_global_attention_mask,
        ).last_hidden_state[:, 0, :]

        # Both slices are already 2-D, so they can be concatenated on the feature axis directly.
        fused = torch.cat([vit_embds, longformer_embds], dim=1)
        return self.classifier(fused)

In [27]:
# Build the exportable model from the composite config (weights are loaded next).
model = ViTLongformerModel(config=config)

In [28]:
# Same training checkpoint as the plain model; the submodule names match, so all keys map.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load(
        DATA_DIR.joinpath('checkpoints', 'model-3-epoch.pt'),
        map_location='cpu',
    )
)

<All keys matched successfully>

In [29]:
# save_pretrained writes the weights *and* config.json from model.config. That is the
# `config` object itself, because the model's __init__ re-assigns self.config = config,
# so a separate config.save_pretrained(MODEL_DIR) would only rewrite the same file.
model.save_pretrained(MODEL_DIR)

In [30]:
# Export the text tower's tokenizer next to the weights.
# NOTE(review): the saved files include vocab.txt and added_tokens.json, which hints
# that this checkpoint ships a WordPiece vocab and that LongformerTokenizerFast may
# have added its own special tokens on top of it. Confirm this matches the tokenizer
# used in training (AutoTokenizer may be the safer choice).
LONGFORMER_CHECKPOINT = 'kazzand/ru-longformer-tiny-16384'
tokenizer = LongformerTokenizerFast.from_pretrained(LONGFORMER_CHECKPOINT)
tokenizer.save_pretrained(MODEL_DIR)

('/Users/vasilyperekhrest/PycharmProjects/malware-detection/model-2/tokenizer_config.json',
 '/Users/vasilyperekhrest/PycharmProjects/malware-detection/model-2/special_tokens_map.json',
 '/Users/vasilyperekhrest/PycharmProjects/malware-detection/model-2/vocab.txt',
 '/Users/vasilyperekhrest/PycharmProjects/malware-detection/model-2/added_tokens.json',
 '/Users/vasilyperekhrest/PycharmProjects/malware-detection/model-2/tokenizer.json')

In [32]:

# Export the preprocessing settings published with the ViT checkpoint so inference
# can prepare images the same way. NOTE(review): presumably these match the training
# preprocessing — confirm.
VIT_CHECKPOINT = 'WinKawaks/vit-small-patch16-224'
image_processor = AutoImageProcessor.from_pretrained(VIT_CHECKPOINT)
image_processor.save_pretrained(MODEL_DIR)

preprocessor_config.json: 100%|██████████| 160/160 [00:00<00:00, 322kB/s]


['/Users/vasilyperekhrest/PycharmProjects/malware-detection/model-2/preprocessor_config.json']