In [1]:
from sklearn.metrics import precision_score, recall_score, f1_score, roc_curve, roc_auc_score
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
import numpy as np
from pathlib import Path
from torchsummary import summary
torch.set_printoptions(sci_mode=False)

In [2]:
# --- Input geometry / loader settings ----------------------------------
_N = 32              # image side length; passed to get_data and used for the model's dummy input
batch_size = 128
# NOTE(review): this name is rebound by the get_data(...) unpacking in the
# next cell, so 43 is only the value until that cell has run.
n_classes = 43

# --- Normalisation statistics (used by to_image / denormalise) ----------
mean = 0.3211
std = 0.2230

# --- Training / runtime --------------------------------------------------
epochs = 100
torch.set_float32_matmul_precision('high')
# First CUDA device when one is visible, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
from dataset import get_data

# get_data returns an 8-tuple; one name per line makes the order easy to audit.
# Note that n_classes from the config cell is replaced here.
(
    PATH,
    LABELS,
    normalise,
    GROUPS,
    GROUP_NAMES,
    n_classes,
    train_loader,
    test_loader,
) = get_data(_N, batch_size)

In [4]:
def to_image(img):
    """Convert a normalised image tensor to an un-normalised array for viewing.

    Args:
        img: 3-D tensor whose first axis is moved last (presumably
            channels-first, i.e. (C, H, W) -- confirm against callers).

    Returns:
        NumPy array holding ``img * std + mean`` with axes reordered to
        ``(1, 2, 0)`` of the input.
    """
    # detach()/cpu() make this safe for tensors that track gradients or live
    # on the GPU; a bare .numpy() raises for both.
    return (img * std + mean).detach().cpu().permute(1,2,0).numpy()

In [5]:
groups_to_keep = {2,3,5}
# Sets have no guaranteed iteration order, so sort before numbering: the
# old -> new mapping is then the same on every run and interpreter.
{old : new for new,old in enumerate(sorted(groups_to_keep))}

{2: 0, 3: 1, 5: 2}

In [6]:
# One row per group index 0..2; position j holds 1 when the j-th entry of
# GROUPS equals that group index and 0 otherwise.
GROUP_MATRIX = [
    [int(member == group_idx) for member in GROUPS]
    for group_idx in range(3)
]

# Print one row per line, separated by ",\n".
print(*GROUP_MATRIX, sep=",\n")

GROUP_NAMES

[0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0],
[1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0]


['Unique Signs', 'Derestriction Signs', 'Other Prohibitory Signs']

In [7]:
class Model(torch.nn.Module):
    """Small CNN: two conv + average-pool stages followed by three linear layers.

    The input is a single-channel ``_N`` x ``_N`` image batch and the output is
    one raw score (no final activation) per class, ``n_classes`` in total.
    Both ``_N`` and ``n_classes`` are read from module level when the model is
    constructed.
    """

    def __init__(self):
        super().__init__()

        self.activation = torch.nn.functional.relu
        self.pool = torch.nn.AvgPool2d(2,2)

        # Layer widths are inferred by pushing a dummy input through the stack
        # as it is built. A single zero image is enough for that, and no_grad
        # keeps the probe from recording an autograd graph. Parameters created
        # inside the context still have requires_grad=True.
        with torch.no_grad():
            x = torch.zeros((1,1,_N,_N))

            self.conv1 = torch.nn.Conv2d(1,6,5)
            x = self.pool(self.activation(self.conv1(x)))

            self.conv2 = torch.nn.Conv2d(x.shape[1],16,5)
            x = self.pool(self.activation(self.conv2(x)))

            x = torch.flatten(x, start_dim=1)

            self.dense1 = torch.nn.Linear(x.shape[1],120)
            x = self.activation(self.dense1(x))

            self.dense2 = torch.nn.Linear(x.shape[1],80)
            x = self.activation(self.dense2(x))

            self.final = torch.nn.Linear(x.shape[1],n_classes)

    def forward(self,x):
        """Return the raw class scores for a batch of images ``x``."""
        x = self.pool(self.activation(self.conv1(x)))
        x = self.pool(self.activation(self.conv2(x)))

        # Collapse everything except the batch axis before the dense layers.
        x = torch.flatten(x, start_dim=1)

        x = self.activation(self.dense1(x))
        x = self.activation(self.dense2(x))
        x = self.final(x)

        return x
        
model = Model().to(device)
# Use _N rather than a literal 32 so the summary always describes the same
# input size the model was built for.
summary(model,(1,_N,_N),batch_size)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [128, 6, 28, 28]             156
         AvgPool2d-2           [128, 6, 14, 14]               0
            Conv2d-3          [128, 16, 10, 10]           2,416
         AvgPool2d-4            [128, 16, 5, 5]               0
            Linear-5                 [128, 120]          48,120
            Linear-6                  [128, 80]           9,680
            Linear-7                  [128, 12]             972
Total params: 61,344
Trainable params: 61,344
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.50
Forward/backward pass size (MB): 7.90
Params size (MB): 0.23
Estimated Total Size (MB): 8.64
----------------------------------------------------------------


In [9]:
# Membership matrix: row g has a 1 at every label index that belongs to group g.
# This rebinds GROUP_MATRIX as an integer tensor.
GROUP_MATRIX = torch.tensor([
    [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0],
    [1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0]
    ])

# Derive the index lists from the matrix instead of typing them out again, so
# the three tables cannot drift apart.
# in_group[g]     -> label indices inside group g
# out_of_group[g] -> every other label index
in_group = [torch.nonzero(row).flatten().tolist() for row in GROUP_MATRIX]
out_of_group = [torch.nonzero(row == 0).flatten().tolist() for row in GROUP_MATRIX]

def constraint_accuracy(preds,labels): # CA calculated for a batch
    """Fraction of samples whose scores respect the group constraint.

    A sample satisfies the constraint when no score belonging to a label
    outside its true group is strictly greater than any score belonging to a
    label inside that group.

    Args:
        preds: per-sample score tensors (one row per sample); may live on any
            device and may track gradients.
        labels: per-sample label indices, used to look the group up in ``GROUPS``.

    Returns:
        Number of satisfied samples divided by ``len(labels)``.

    Raises:
        ZeroDivisionError: if ``labels`` is empty.
    """
    # Move the scores to the CPU once. Comparing individual elements of a GPU
    # tensor inside Python loops forces a device sync per comparison.
    preds = preds.detach().cpu()

    total = 0
    for pred,label in zip(preds,labels):
        group = GROUPS[label]
        inside = pred[in_group[group]]
        outside = pred[out_of_group[group]]
        # Broadcast to an (n_in, n_out) grid of "outside > inside" checks;
        # any True entry means the constraint is violated.
        violated = bool((outside.unsqueeze(0) > inside.unsqueeze(1)).any())
        if not violated:
            total += 1
    return total / len(labels)

In [21]:
def constraint_train(model,loss_func,optimiser,epochs):
    """Train ``model`` on group targets, printing one progress line per epoch.

    Each batch from the module-level ``train_loader`` is converted to group
    targets with ``labels_to_groups`` and used for one optimiser step. After
    every epoch the validation loss is computed and handed to an
    ``EarlyStopping(5)`` helper (defined elsewhere; the meaning of ``5`` is not
    visible here -- presumably a patience value). Training ends as soon as that
    helper asks to stop, or after ``epochs`` epochs.
    """
    stopper = EarlyStopping(5)

    for epoch in range(epochs):
        total_loss = 0
        ca = 0

        for batch_images, batch_labels in train_loader:
            batch_images = batch_images.to(device)
            targets = labels_to_groups(batch_labels).to(device)

            optimiser.zero_grad()
            logits = model(batch_images)
            batch_loss = loss_func(logits, targets)
            batch_loss.backward()
            optimiser.step()

            # Running sums; both are averaged over the number of batches below.
            total_loss += batch_loss.item()
            ca += constraint_accuracy(logits,batch_labels)

        validation_loss = constraint_validation_loss(model,loss_func)
        print(f"Epoch {epoch+1} --- Training Loss {total_loss / len(train_loader):.3f} --- Validation Loss {validation_loss:.3f} --- Constraint Accuracy {ca / len(train_loader):.3f}")

        if not early_stop_requested(stopper, validation_loss):
            continue
        print("Early stopping")
        return


def early_stop_requested(stopper, validation_loss):
    """Ask the early-stopping helper whether training should end now."""
    return stopper.should_stop_early(validation_loss)

def constraint_validation_loss(model,loss_func):
    """Mean per-batch ``loss_func`` value of ``model`` over ``test_loader``.

    Targets are the group encodings produced by ``labels_to_groups``. The
    whole pass runs without gradient tracking.
    """
    running = 0
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            targets = labels_to_groups(batch_labels).to(device)
            logits = model(batch_images.to(device))
            batch_value = loss_func(logits,targets).item()
            running += batch_value
    n_batches = len(test_loader)
    return running / n_batches

def labels_to_groups(labels):
    """Map a batch of labels to the membership row of each label's group.

    Args:
        labels: 1-D batch of label indices (tensor or sequence of ints).

    Returns:
        CPU tensor of shape ``(len(labels), n_columns)`` in the default float
        dtype, where row ``i`` is ``GROUP_MATRIX[GROUPS[labels[i]]]`` and
        ``n_columns`` is the row length of ``GROUP_MATRIX``.
    """
    # as_tensor accepts both the list-of-lists and the tensor form of
    # GROUP_MATRIX, and the width comes from the matrix itself rather than a
    # hard-coded 12.
    group_rows = torch.as_tensor(GROUP_MATRIX).to("cpu", torch.get_default_dtype())
    group_of_label = torch.as_tensor(GROUPS).to("cpu", torch.long)
    label_idx = torch.as_tensor(labels).to("cpu", torch.long)

    # Two gathers replace the per-sample Python loop: label -> group -> row.
    # Advanced indexing returns a fresh tensor, so callers may modify it freely.
    return group_rows[group_of_label[label_idx]]

In [47]:

def pgd_attack(model,images,groups,loss_func,epsilon,iterations,decay_rate,learning_rate,momentum_decay):
    """Craft adversarial inputs with an iterative signed-gradient attack.

    The attack works in de-normalised space. Each step moves along the sign of
    the loss gradient plus a running average of earlier signs, then clips the
    result to within ``epsilon`` of the clean image and to ``[0, 1]``.

    Args:
        model: network under attack; it is called on normalised inputs.
        images: batch of normalised inputs.
        groups: targets passed to ``loss_func`` together with the predictions.
        loss_func: ``loss_func(pred, groups)`` returning a scalar; the attack
            ascends this loss.
        epsilon: largest allowed per-element deviation from the clean image,
            in de-normalised units.
        iterations: number of attack steps (must be at least 1).
        decay_rate: base-2 exponent at which the step-size schedule starts;
            the schedule ends at exponent 1.
        learning_rate: step size used by the first iteration.
        momentum_decay: weight given to the previous momentum when averaging.

    Returns:
        The adversarial batch passed through ``normalise``.

    Side effects:
        Calls ``model.zero_grad()`` and ``loss.backward()`` on every step, so
        the model's parameter gradients are overwritten.
    """
    # detach(): the attack itself must never be part of an autograd graph.
    adversarial = denormalise(images).detach()
    lower_bound = adversarial - epsilon
    upper_bound = adversarial + epsilon

    # Step sizes: 2**decay_rate ... 2**1 on a log scale, rescaled so that the
    # first step equals learning_rate.
    decay = torch.logspace(decay_rate,1,iterations,base=2)
    decay = decay / decay[0]
    decay *= learning_rate

    # Follow the inputs' own device instead of the module-level `device`.
    momentum = torch.zeros_like(images)

    for alpha in decay:
        # A detached copy is always a leaf, so requiring grad on it is valid
        # whatever `normalise` does internally.
        normalised = normalise(adversarial).detach().requires_grad_(True)
        model.zero_grad()
        pred = model(normalised)
        loss = loss_func(pred,groups)

        loss.backward()

        perturbations = torch.sign(normalised.grad)

        # The step uses the momentum accumulated *before* this iteration.
        adversarial = adversarial + (perturbations + momentum) * alpha

        momentum = momentum * momentum_decay + (1 - momentum_decay) * perturbations

        # Project back into the epsilon box, then into the valid value range.
        adversarial = torch.clip(adversarial,lower_bound, upper_bound)
        adversarial = torch.clip(adversarial,0,1)

    return normalise(adversarial)



# Move the normalisation object returned by get_data onto the compute device.
# NOTE(review): the return value is discarded, so this only has an effect if
# .to() works in place (as it does for torch.nn.Module) -- confirm the type.
normalise.to(device)

@torch.compile
def denormalise(images):
    """Return ``images`` scaled by the module-level ``std`` and shifted by ``mean``."""
    rescaled = images * std
    return rescaled + mean


In [56]:
def evaluate(model,epsilon):
    """Print clean classification metrics and constraint security for ``model``.

    Iterates over the module-level ``test_loader``. For every batch it
    (1) measures constraint security under an attack of size ``epsilon`` using
    the module-level ``loss_func`` and (2) records the model's predictions on
    the unmodified images.

    Precision, recall and F1 are computed with ``average="weighted"`` and are
    printed with that label. Constraint security is the mean of the per-batch
    values.

    Args:
        model: network to evaluate.
        epsilon: attack budget forwarded to ``constraint_security``.
    """
    predictions = []
    labels = []
    cs = 0

    for images,lbls in test_loader:
        # The attack needs gradients, so it must stay outside no_grad.
        cs += constraint_security(images,lbls,model,epsilon,loss_func)

        labels.extend(lbls.numpy())

        # The clean forward pass is inference only: skip the autograd graph.
        with torch.no_grad():
            preds = model(images.to(device))
        predictions.extend(preds.cpu().numpy())

    cs /= len(test_loader)

    labels = np.array(labels)
    predictions = np.array(predictions)
    pred_class = np.argmax(predictions,axis=1)
    precision = precision_score(labels, pred_class, average="weighted")
    recall = recall_score(labels, pred_class, average="weighted")
    f1 = f1_score(labels, pred_class, average="weighted")
    # The labels now name the averaging mode actually used; these lines used
    # to say "(macro)".
    print(f'Precision (weighted): {precision:.4f}')
    print(f'Recall (weighted): {recall:.4f}')
    print(f'F1-score (weighted): {f1:.4f}')
    print(f"Constraint Security: {cs:.4f}")

In [58]:
def constraint_security(images,labels,model,epsilon,loss_func,
                        iterations=40,decay_rate=6,learning_rate=40 / 255,momentum_decay=0.8):
    """Constraint accuracy of ``model`` on adversarially perturbed ``images``.

    Args:
        images: batch of normalised inputs (moved to ``device`` internally).
        labels: per-sample labels for the batch.
        model: network under attack.
        epsilon: attack budget forwarded to ``pgd_attack``.
        loss_func: loss the attack ascends; it receives the predictions and the
            group encoding of ``labels``.
        iterations: number of attack steps.
        decay_rate: start exponent of the attack's step-size schedule.
        learning_rate: first step size of the attack.
        momentum_decay: momentum averaging weight of the attack.

    Returns:
        ``constraint_accuracy`` of the model's predictions on the adversarial
        batch.

    The four attack hyper-parameters default to the values that were
    previously hard-coded, so existing five-argument calls are unaffected.
    """
    groups = labels_to_groups(labels).to(device)

    adv = pgd_attack(model,images.to(device),groups,loss_func,epsilon,iterations,decay_rate,learning_rate,momentum_decay)
    # Scoring the adversarial batch is inference only; no graph is needed.
    with torch.no_grad():
        pred = model(adv)
    return constraint_accuracy(pred,labels)

In [59]:
# Attack budget shared by the evaluation cells below.
epsilon = 5 / 255
loss_func = torch.nn.BCEWithLogitsLoss().to(device)
model = Model().to(device)
# map_location lets a checkpoint saved on another device (e.g. a GPU) load on
# this one. weights_only limits unpickling to tensors and plain containers,
# which is all a state dict needs.
model.load_state_dict(torch.load("models/constraint_only_10_eps.pth", map_location=device, weights_only=True))
evaluate(model,epsilon)

Precision (macro): 0.3155
Recall (macro): 0.2479
F1-score (macro): 0.2554
Constraint Security: 0.8797


In [62]:
model = Model().to(device)
# map_location lets a checkpoint saved on another device load on this one.
# weights_only limits unpickling to tensors and plain containers.
model.load_state_dict(torch.load("models/hybrid_10_eps.pth", map_location=device, weights_only=True))
evaluate(model,epsilon)

Precision (macro): 0.9917
Recall (macro): 0.9915
F1-score (macro): 0.9915
Constraint Security: 0.7061


In [63]:
model = Model().to(device)
# map_location lets a checkpoint saved on another device load on this one.
# weights_only limits unpickling to tensors and plain containers.
model.load_state_dict(torch.load("models/superclass_base.pth", map_location=device, weights_only=True))
evaluate(model,epsilon)

Precision (macro): 0.9878
Recall (macro): 0.9877
F1-score (macro): 0.9877
Constraint Security: 0.0021
