# Import Libraries

In [49]:
import os

import torch
import torch.nn as nn

# Define Safe Model

In [50]:
class SafeModel(nn.Module):
    """Small CNN classifier assembled only from stock ``torch.nn`` layers.

    Pipeline: 3x3 convolution (1 -> 32 channels) -> ReLU -> flatten ->
    linear layer with 10 outputs -> softmax over dim 1.
    """

    def __init__(self):
        super().__init__()
        # Attribute names become the state_dict key prefixes and the creation
        # order fixes the order reported by ``modules()``, so both are kept.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        # 32 * 26 * 26 is what an unpadded 3x3 conv yields for a 28x28 input.
        self.fc = nn.Linear(in_features=32 * 26 * 26, out_features=10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return the softmax output of the network for ``x``.

        ``x`` must flatten to 32 * 26 * 26 features per sample after the
        convolution, e.g. a (batch, 1, 28, 28) tensor.
        """
        features = self.flatten(self.relu(self.conv1(x)))
        logits = self.fc(features)
        return self.softmax(logits)

In [57]:
# Build the baseline model; nothing visible here trains it, so it keeps its initial weights.
safe_model = SafeModel()

In [58]:
safe_model_path = "./models/safe_model.pt"
# torch.save does not create missing parent directories, so make sure
# ./models exists before writing.
os.makedirs(os.path.dirname(safe_model_path), exist_ok=True)
# Persist only the parameters (state_dict), not the pickled model object.
torch.save(safe_model.state_dict(), safe_model_path)

In [56]:
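# NOTE: safe_model_path holds a state_dict, so this call would return a dict of tensors
# rather than a SafeModel, and weights_only=False lets the pickle run arbitrary code.
#safe_model = torch.load(safe_model_path, weights_only=False)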

# Define Malicious Model

In [59]:
class CustomLayer(nn.Module):
    """Parameter-free layer that applies the element-wise map ``x * 2 + 1``.

    This class is defined here rather than provided by ``torch.nn``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Return ``inputs`` doubled and then incremented by one."""
        doubled = inputs * 2
        return doubled + 1

In [60]:
class MaliciousModel(nn.Module):
    """CNN classifier that inserts a user-defined ``CustomLayer`` after the ReLU.

    Pipeline: 3x3 convolution (1 -> 32 channels) -> ReLU -> CustomLayer ->
    flatten -> linear layer with 10 outputs -> softmax over dim 1.
    """

    def __init__(self):
        super().__init__()
        # Attribute names become the state_dict key prefixes and the creation
        # order fixes the order reported by ``modules()``, so both are kept.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.relu = nn.ReLU()
        self.custom = CustomLayer()
        self.flatten = nn.Flatten()
        # 32 * 26 * 26 is what an unpadded 3x3 conv yields for a 28x28 input.
        self.fc = nn.Linear(in_features=32 * 26 * 26, out_features=10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return the softmax output of the network for ``x``.

        ``x`` must flatten to 32 * 26 * 26 features per sample after the
        convolution, e.g. a (batch, 1, 28, 28) tensor.
        """
        activated = self.relu(self.conv1(x))
        features = self.flatten(self.custom(activated))
        logits = self.fc(features)
        return self.softmax(logits)

In [61]:
# Build the model that contains the user-defined layer; nothing visible here trains it.
malicious_model = MaliciousModel()

In [62]:
malicious_model_path = "./models/malicious_model.pt"
# torch.save does not create missing parent directories, so make sure
# ./models exists before writing.
os.makedirs(os.path.dirname(malicious_model_path), exist_ok=True)
# Persist only the parameters (state_dict), not the pickled model object.
torch.save(malicious_model.state_dict(), malicious_model_path)

In [63]:
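# NOTE: malicious_model_path holds a state_dict, so this call would return a dict of tensors
# rather than a MaliciousModel, and weights_only=False lets the pickle run arbitrary code.
#malicious_model = torch.load(malicious_model_path, weights_only=False)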

# Custom Layer Detection

In [80]:
def detect_custom_layer(model):
    """Return the class names of submodules that are not genuine ``torch.nn`` layers.

    A submodule counts as standard only when ``torch.nn`` exports, under the
    submodule's class name, that very class object. Comparing by identity
    rather than by name alone means a user-defined class that merely reuses a
    standard name (for example its own ``Linear``) is still reported.

    Args:
        model: The ``nn.Module`` to inspect. The model object itself is not
            checked, only the modules nested inside it.

    Returns:
        A list of class-name strings, one per non-standard submodule, in
        ``model.modules()`` traversal order. Empty when every submodule is a
        ``torch.nn`` class.
    """
    submodules = model.modules()
    # modules() yields the model itself first; skip it without building a list.
    next(submodules, None)

    custom_layers = []
    for module in submodules:
        layer_cls = type(module)
        # A name match is not enough: the class must be the one torch.nn exports.
        if getattr(nn, layer_cls.__name__, None) is not layer_cls:
            custom_layers.append(layer_cls.__name__)

    return custom_layers

In [81]:
# SafeModel contains only torch.nn layers, so no class names should be reported.
detect_custom_layer(safe_model)

[]

In [82]:
# MaliciousModel embeds CustomLayer, which torch.nn does not provide, so it should be reported.
detect_custom_layer(malicious_model)

['CustomLayer']