# Running ModelScan on a PyTorch Model

## Import statements

```python
import os

import torch
from utils.pytorch_sentiment_model import download_model, predict_sentiment
from utils.pickle_codeinjection import PickleInject, get_payload

# Disable Hugging Face tokenizers parallelism to avoid fork warnings in the notebook
%env TOKENIZERS_PARALLELISM=false
```

## Download and save the model

We use a RoBERTa-based sentiment analysis PyTorch model (https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment). The safe model is saved to `./PyTorchModels/safe_model.pt`.

```python
# Save a model for sentiment analysis
from typing import Final

model_directory: Final[str] = "PyTorchModels"
# Create the output directory if it does not exist yet
os.makedirs(model_directory, exist_ok=True)

safe_model_path = os.path.join(model_directory, "safe_model.pt")

# Download the model and serialize it to safe_model_path (helper from utils/)
download_model(safe_model_path)
```
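Since PyTorch 1.6, `torch.save` writes a zip archive by default: the pickled object graph sits in a `.pkl` entry (typically `<name>/data.pkl`) and tensor storages are stored as separate entries. The pickle entry is where a serialization payload would live. A minimal standard-library sketch for peeking inside the saved file (the helper name `describe_pt_archive` is hypothetical):

```python
import zipfile


def describe_pt_archive(path):
    """Return (all entry names, pickle entry names) for a zip-format torch.save() file.

    `path` may be a filesystem path or a binary file-like object.
    """
    if not zipfile.is_zipfile(path):
        raise ValueError("not a zip archive; possibly the legacy torch.save format")
    with zipfile.ZipFile(path) as zf:
        names = zf.namelist()
    # The pickled object graph lives in the .pkl entry (e.g. "<name>/data.pkl")
    pickles = [name for name in names if name.endswith(".pkl")]
    return names, pickles
```

For example, `describe_pt_archive(safe_model_path)`. Files written in the legacy format (for instance with `_use_new_zipfile_serialization=False`) are not zip archives, so the sketch raises `ValueError` for them.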

## Run the model

Run the safe model to verify that it has been downloaded correctly.

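```python
# torch.load unpickles the full model object. PyTorch 2.6+ defaults to
# weights_only=True, which rejects full (non-state_dict) models, so opt out
# explicitly; only do this for trusted files (argument needs PyTorch 1.13+).
sentiment = predict_sentiment(
    "Stock market was bearish today", torch.load(safe_model_path, weights_only=False)
)
```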

## Run ModelScan on the safe model

Now scan the safe model with the `modelscan` command. This assumes ModelScan is installed in the active virtual environment (for example with `pip install modelscan`).

**The scan results list the files scanned and any issues found. For the safe model, ModelScan finds no model serialization attacks.**

```
!modelscan --path PyTorchModels/safe_model.pt
```
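Scanners such as ModelScan can reach this verdict without loading the model, because a pickle names every module attribute it needs through `GLOBAL`/`STACK_GLOBAL` opcodes, and those can be read statically. The sketch below shows the general idea with the standard library's `pickletools`; it is not ModelScan's implementation, and the helper names `list_globals` and `globals_in_pt` are hypothetical. It walks the opcodes, tracking pushed strings and memo slots so that protocol 4+ `STACK_GLOBAL` imports are resolved too.

```python
from __future__ import annotations

import pickletools
import zipfile

# Opcodes that push a str onto the unpickler's stack
_STR_OPS = {"STRING", "BINSTRING", "SHORT_BINSTRING",
            "UNICODE", "BINUNICODE", "SHORT_BINUNICODE", "BINUNICODE8"}
_MEMO_PUT_OPS = {"MEMOIZE", "PUT", "BINPUT", "LONG_BINPUT"}
_MEMO_GET_OPS = {"GET", "BINGET", "LONG_BINGET"}


def list_globals(data: bytes) -> set[tuple[str, str]]:
    """Return the (module, name) pairs a pickle would import, without unpickling it."""
    found: set[tuple[str, str]] = set()
    strings: list[str] = []      # every str pushed so far, most recent last
    memo: dict[int, str] = {}    # memo slots known to hold a str
    slots: set[int] = set()      # all memo slots used (MEMOIZE writes at len(memo))
    top_str = None               # str currently on top of the stack, if known
    for opcode, arg, _pos in pickletools.genops(data):
        op = opcode.name
        if op in _MEMO_PUT_OPS:
            # Memo writes copy the stack top into a slot and leave the stack unchanged
            slot = len(slots) if op == "MEMOIZE" else arg
            slots.add(slot)
            if top_str is not None:
                memo[slot] = top_str
            else:
                memo.pop(slot, None)
            continue
        pushed = None
        if op in ("GLOBAL", "INST"):
            module, name = arg.split(" ", 1)
            found.add((module, name))
        elif op == "STACK_GLOBAL" and len(strings) >= 2:
            # Protocol 4+: module and name are the two strings pushed just before
            found.add((strings[-2], strings[-1]))
        elif op in _STR_OPS:
            pushed = arg
        elif op in _MEMO_GET_OPS:
            pushed = memo.get(arg)
        if pushed is not None:
            strings.append(pushed)
        top_str = pushed
    return found


def globals_in_pt(path) -> set[tuple[str, str]]:
    """Scan every pickle entry stored inside a zip-format torch.save() file."""
    found: set[tuple[str, str]] = set()
    with zipfile.ZipFile(path) as zf:
        for entry in zf.namelist():
            if entry.endswith(".pkl"):
                found |= list_globals(zf.read(entry))
    return found
```

For the safe model, `globals_in_pt("PyTorchModels/safe_model.pt")` should list only the classes and functions needed to rebuild the model. A production scanner must handle more cases than this sketch does, for example imports resolved through the `EXT*` opcodes.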

## Model Serialization Attack

Next, we inject malicious code into the safe model and save the result as a new model, `./PyTorchModels/unsafe_model.pt`.

The injected code is a shell command that tries to print the AWS secret keys stored in `~/.aws/secrets`.

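```python
# Payload type understood by get_payload (see utils/pickle_codeinjection.py)
command = "system"
# Shell command the payload runs when the model is unpickled
malicious_code = "cat ~/.aws/secrets"

unsafe_model_path = os.path.join(model_directory, "unsafe_model.pt")

payload = get_payload(command, malicious_code)
# Re-save the safe model through a custom pickle module that injects the payload
torch.save(
    torch.load(safe_model_path, weights_only=False),
    f=unsafe_model_path,
    pickle_module=PickleInject([payload]),
)
```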

## Unsafe Model Prediction

The malicious code injected into the unsafe model runs as soon as the model is loaded.

In the output, you can see that the AWS secret keys are displayed.

The unsafe model also predicts sentiment just as well as the safe model: the injection does not alter the model's weights, so its performance is unaffected. Because a compromised model behaves normally, ML model files make an effective attack vector.

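```python
# Loading (unpickling) the unsafe model runs the injected command before prediction.
# On PyTorch 2.6+, weights_only=False is what lets the pickled payload execute.
predict_sentiment(
    "Stock market was bearish today", torch.load(unsafe_model_path, weights_only=False)
)
```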
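Loading runs code because `torch.load` unpickles the model's pickle data, and pickle lets an object's `__reduce__` name any importable callable plus arguments to be called while loading. The harmless sketch below (class and helper names are hypothetical) uses `print` as a stand-in for the injected shell command, then blocks it with the "Restricting Globals" recipe from the Python `pickle` documentation: a `pickle.Unpickler` subclass whose `find_class` only permits an explicit allow-list. PyTorch's `torch.load(..., weights_only=True)` restricts unpickling to an allow-list in a similar way.

```python
import io
import pickle


class PrintOnLoad:
    """Harmless stand-in for an injected payload."""

    def __reduce__(self):
        # Pickle stores "call builtins.print with these args" instead of the object,
        # so print() runs inside pickle.loads()
        return (print, ("code ran during unpickling",))


blob = pickle.dumps(PrintOnLoad())

# Only these (module, name) globals may be imported while loading (illustrative)
ALLOWED = {("collections", "OrderedDict")}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if (module, name) in ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    """Unpickle data, refusing any global outside ALLOWED."""
    return RestrictedUnpickler(io.BytesIO(data)).load()
```

`pickle.loads(blob)` prints the message, while `restricted_loads(blob)` raises `pickle.UnpicklingError` before `print` is ever looked up. Plain containers such as dicts and lists need no globals, so they still load.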

## Run ModelScan on the unsafe model

The scan results list the files scanned and any issues found. This time, ModelScan reports a critical-severity issue in the unsafe model.

ModelScan also lists the operator(s) and module(s) it flagged as unsafe.

```
!modelscan --path ./PyTorchModels/unsafe_model.pt
```
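Once the imported globals are known, flagging them comes down to a lookup against a list of dangerous modules and functions, each with a severity. A minimal sketch with a hypothetical deny-list and severity labels (illustrative only; ModelScan ships its own configuration, which this does not reproduce). Note that on Linux and macOS `os.system` is pickled as `posix.system`, and on Windows as `nt.system`, so a deny-list must cover those module names too.

```python
# Hypothetical deny-list: "*" flags every name in the module,
# a set flags only the listed names. Not ModelScan's real configuration.
DENY_LIST = {
    "os": ("*", "CRITICAL"),
    "posix": ("*", "CRITICAL"),   # os.system pickles as posix.system on Linux/macOS
    "nt": ("*", "CRITICAL"),      # ...and as nt.system on Windows
    "subprocess": ("*", "CRITICAL"),
    "builtins": ({"eval", "exec", "compile", "__import__"}, "CRITICAL"),
    "socket": ("*", "HIGH"),
}


def classify(found):
    """Return (severity, 'module.name') for each (module, name) pair on the deny-list."""
    issues = []
    for module, name in sorted(found):
        rule = DENY_LIST.get(module)
        if rule is None:
            continue
        names, severity = rule
        if names == "*" or name in names:
            issues.append((severity, f"{module}.{name}"))
    return issues
```

For example, `classify({("posix", "system"), ("collections", "OrderedDict")})` flags only `posix.system`; harmless globals such as model classes pass through.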

## Change the reporting format of the output

Use `-r json` to write the report as JSON and `-o` to choose the output file. The command below saves the scan results to `pytorch-model-scan-results.json`.

```
!modelscan --path ./PyTorchModels/unsafe_model.pt -r json -o pytorch-model-scan-results.json
```
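A JSON report is convenient in automated pipelines, where the process exit code decides whether a model is accepted. A minimal sketch of such a gate, assuming the exit codes listed in the ModelScan README at the time of writing (verify them against your installed version); the helper names `scan_to_json` and `gate` are hypothetical:

```python
import subprocess

# Assumed ModelScan exit codes (from its README; check your version)
EXIT_CODES = {
    0: "scan completed, no issues found",
    1: "scan completed, issues found",
    2: "scan failed with an error",
    3: "no supported files were scanned",
    4: "usage error (invalid or incomplete CLI options)",
}


def scan_to_json(model_path: str, report_path: str) -> int:
    """Run modelscan, write a JSON report to report_path, and return its exit code."""
    completed = subprocess.run(
        ["modelscan", "--path", model_path, "-r", "json", "-o", report_path],
        check=False,  # a non-zero code is a result to inspect, not a crash
    )
    return completed.returncode


def gate(exit_code: int) -> None:
    """Raise unless the scan was clean; 'no supported files' (3) also fails."""
    if exit_code != 0:
        meaning = EXIT_CODES.get(exit_code, "unknown exit code")
        raise RuntimeError(f"modelscan exit code {exit_code}: {meaning}")
```

Usage: `gate(scan_to_json("PyTorchModels/unsafe_model.pt", "pytorch-model-scan-results.json"))`. Treating code 3 as a failure is deliberate: a file the scanner skipped has not been shown to be safe.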