In [1]:
from __future__ import print_function

from collections import namedtuple

import argparse

import time
import os
import csv

import sys

sys.path.append(".")
sys.path.append("..")

import numpy as np

import onnx

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils
import torch.utils.data
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image

from base.backends import TorchBackend
from base.logic import Logic
from base.dl2 import DL2
from base.fuzzy_logics import *

from training.constraints import *
from training.models import *

from training.group_definitions import gtsrb_groups, cifar10_groups

from training.util import *
from training.grad_norm import *
from training.attacks import *

EpochInfoTrain = namedtuple('EpochInfoTrain', 'pred_acc constr_acc constr_sec pred_loss random_loss constr_loss pred_loss_weight constr_loss_weight input_img adv_img random_img')
EpochInfoTest = namedtuple('EpochInfoTest', 'pred_acc constr_acc constr_sec pred_loss random_loss constr_loss input_img adv_img random_img vacuously_true')


In [2]:
def train(model: torch.nn.Module, device: torch.device, train_loader: torch.utils.data.DataLoader, optimizer, oracle: Attack, grad_norm: GradNorm, logic: Logic, constraint: Constraint, with_dl: bool) -> EpochInfoTrain:
    avg_pred_acc, avg_pred_loss = torch.tensor(0., device=device), torch.tensor(0., device=device)
    avg_constr_acc, avg_constr_sec, avg_constr_loss, avg_random_loss = torch.tensor(0., device=device), torch.tensor(0., device=device), torch.tensor(0., device=device), torch.tensor(0., device=device)

    images = { 'input': None, 'random': None, 'adv': None}

    model.train()

    for _, (data, target) in enumerate(train_loader, start=1):
        inputs, labels = data.to(device), target.to(device)

        # forward pass for prediction accuracy
        outputs = model(inputs)
        ce_loss = F.cross_entropy(outputs, labels)
        correct = torch.mean(torch.argmax(outputs, dim=1).eq(labels).float())

        # get random + adversarial samples
        with torch.no_grad():
            random = oracle.uniform_random_sample(inputs)

        adv = oracle.attack(model, inputs, labels, logic, constraint)

        # forward pass for constraint accuracy (constraint satisfaction on random samples)
        with torch.no_grad():
            loss_random, sat_random = constraint.eval(model, inputs, random, labels, logic, reduction='mean')

        # forward pass for constraint security (constraint satisfaction on adversarial samples)
        with maybe(torch.no_grad(), not with_dl):
            loss_adv, sat_adv = constraint.eval(model, inputs, adv, labels, logic, reduction='mean')

        optimizer.zero_grad(set_to_none=True)

        if not with_dl:
            ce_loss.backward()
            optimizer.step()
        else:
            grad_norm.balance(ce_loss, loss_adv)

        avg_pred_acc += correct
        avg_pred_loss += ce_loss
        avg_constr_acc += sat_random
        avg_constr_sec += sat_adv
        avg_constr_loss += loss_adv
        avg_random_loss += loss_random

        # save one original image, random sample, and adversarial sample image (for debugging, inspecting attacks)
        i = np.random.randint(0, inputs.size(0) - 1)
        images['input'], images['random'], images['adv'] = inputs[i], random[i], adv[i]

    if with_dl:
        grad_norm.renormalise()

    return EpochInfoTrain(
        pred_acc=avg_pred_acc.item() / len(train_loader),
        constr_acc=avg_constr_acc.item() / len(train_loader),
        constr_sec=avg_constr_sec.item() / len(train_loader),
        pred_loss=avg_pred_loss.item() / len(train_loader),
        random_loss=avg_random_loss.item() / len(train_loader),
        constr_loss=avg_constr_loss.item() / len(train_loader),
        pred_loss_weight=grad_norm.weights[0].item(),
        constr_loss_weight=grad_norm.weights[1].item(),
        input_img=images['input'],
        adv_img=images['adv'],
        random_img=images['random']
    )

In [3]:
def test(model: torch.nn.Module, device: torch.device, test_loader: torch.utils.data.DataLoader, oracle: Attack, logic: Logic, constraint: Constraint) -> EpochInfoTest:
    correct, constr_acc, constr_sec = torch.tensor(0., device=device), torch.tensor(0., device=device), torch.tensor(0., device=device)
    avg_pred_loss, avg_constr_loss, avg_random_loss = torch.tensor(0., device=device), torch.tensor(0., device=device), torch.tensor(0., device=device)

    record_vacuously_true = isinstance(constraint, EvenOddConstraint) or isinstance(constraint, ClassSimilarityConstraint)

    if record_vacuously_true:
        vacuously_true = torch.zeros(2 if isinstance(constraint, EvenOddConstraint) else 10, device=device)

    total_samples = 0

    images = { 'input': None, 'random': None, 'adv': None}

    model.eval()

    for _, (data, target) in enumerate(test_loader, start=1):
        inputs, labels = data.to(device), target.to(device)
        total_samples += inputs.size(0)

        with torch.no_grad():
            # forward pass for prediction accuracy
            outputs = model(inputs)
            avg_pred_loss += F.cross_entropy(outputs, labels, reduction='sum')
            pred = outputs.max(dim=1, keepdim=True)[1]
            correct += pred.eq(labels.view_as(pred)).sum()

            # get random samples (no grad)
            random = oracle.uniform_random_sample(inputs)

        # get adversarial samples (requires grad)
        adv = oracle.attack(model, inputs, labels, logic, constraint)

        # forward passes for constraint accuracy (constraint satisfaction on random samples) + constraint security (constraint satisfaction on adversarial samples)
        with torch.no_grad():
            loss_random, sat_random = constraint.eval(model, inputs, random, labels, logic, reduction='sum')
            loss_adv, sat_adv = constraint.eval(model, inputs, adv, labels, logic, reduction='sum')

            if record_vacuously_true:
                vacuously_true += constraint.get_vacuously_true(model, adv)

            constr_acc += sat_random
            constr_sec += sat_adv

            avg_random_loss += loss_random
            avg_constr_loss += loss_adv

        # save one original image, random sample, and adversarial sample image (for debugging, inspecting attacks)
        i = np.random.randint(0, inputs.size(0) - 1)
        images['input'], images['random'], images['adv'] = inputs[i], random[i], adv[i]

    return EpochInfoTest(
        pred_acc=correct.item() / total_samples, 
        constr_acc=constr_acc.item() / total_samples,
        constr_sec=constr_sec.item() / total_samples,
        pred_loss=avg_pred_loss.item() / total_samples,
        random_loss=avg_random_loss.item() / total_samples,
        constr_loss=avg_constr_loss.item() / total_samples,
        input_img=images['input'],
        adv_img=images['adv'],
        random_img=images['random'],
        vacuously_true=(vacuously_true / total_samples) if record_vacuously_true else -1.
    )

In [4]:
backend = TorchBackend()

logics: list[Logic] = [
    DL2(backend),
    GoedelFuzzyLogic(backend),
    KleeneDienesFuzzyLogic(backend),
    LukasiewiczFuzzyLogic(backend),
    ReichenbachFuzzyLogic(backend),
    GoguenFuzzyLogic(backend),
    ReichenbachSigmoidalFuzzyLogic(backend),
    YagerFuzzyLogic(backend)
]

In [5]:
GROUPS = [3,3,3,3,3,3,2,3,3,4,4,1,0,0,0,4,4,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,2,5,5,5,5,5,5,5,5,2,2]
GROUP_NAMES = ["Unique Signs","Danger Signs","Derestriction Signs","Speed Limit Signs","Other Prohibitory Signs","Mandatory Signs"]

In [6]:
_N = 32
batch_size = 128
n_classes = 43
epochs = 30
kwargs = {"batch_size": batch_size}

In [7]:

if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.backends.cudnn.deterministic = False #True
    torch.backends.cudnn.benchmark = True #False

    kwargs.update({ 'num_workers': 4, 'pin_memory': True })
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [8]:
mean = 0.3211
std = 0.2230

def to_image(img): # convert to unormalized form for viewing

    return (img * std + mean).permute(1,2,0).numpy()

normalise = transforms.Normalize(mean, std)
transform = transforms.Compose([
    transforms.Resize((_N,_N)),
    transforms.ToTensor(),
    transforms.Grayscale(),
    normalise
])

In [9]:
dataset = torchvision.datasets.ImageFolder(root="/home/rob/code/Project/dataset/GTSRB/Training", transform=transform)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [0.8,0.2])
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=batch_size, 
                                           num_workers=6,
                                           shuffle=True,
                                           drop_last=True
                                           )
test_loader = torch.utils.data.DataLoader(test_dataset, 
                                          batch_size=batch_size, 
                                          num_workers=6,
                                          shuffle=False,
                                          drop_last=True
                                          )

In [10]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        x = torch.zeros((64,1,_N,_N))

        self.activation = torch.nn.functional.relu

        self.pool = torch.nn.AvgPool2d(2,2)

        self.conv1 = torch.nn.Conv2d(1,6,5)
        x = self.pool(self.activation(self.conv1(x)))
        
        self.conv2 = torch.nn.Conv2d(x.shape[1],16,5)
        x = self.pool(self.activation(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        self.dense1 = torch.nn.Linear(x.shape[1],128)
        x = self.activation(self.dense1(x))
        self.dense2 = torch.nn.Linear(x.shape[1],64)
        x = self.activation(self.dense2(x))
        self.final = torch.nn.Linear(x.shape[1],n_classes)


    def forward(self,x):
        x = self.pool(self.activation(self.conv1(x)))
        x = self.pool(self.activation(self.conv2(x)))
        
        x = torch.flatten(x, start_dim=1)

        x = self.activation(self.dense1(x))
        x = self.activation(self.dense2(x))
        x = self.final(x)

        return x

In [11]:

logic = logics[0] # need some logic loss for oracle even for baseline
is_baseline = True

epsilon = 16 / 255
delta = 0.02

constraint = GroupConstraint(device, epsilon, delta, gtsrb_groups)
print(f'constraint.eps={constraint.eps}')

constraint.eps=0.062745101749897


In [12]:
model = Model().to(device)

pgd_iterations = 20
pgd_restarts = 10
oracle = APGD(device, pgd_iterations, pgd_restarts, mean, std, constraint.eps)
oracle_test = APGD(device, pgd_iterations * 2, pgd_restarts, mean, std, constraint.eps)

optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)

grad_norm = GradNorm(model, device, optimizer, lr=0.001, alpha=0.12, initial_dl_weight=1.0)


In [13]:
folder = "abcd"
folder_name = "abcd"
file_name = "abcd/test"
report_file_name = f'{file_name}.csv'
model_file_name = f'{file_name}.onnx'
os.makedirs(folder_name, exist_ok=True)
def save_imgs(info: EpochInfoTrain | EpochInfoTest, epoch):
    return
print(f'using device {device}')
print(f'#model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
with_extra_info = False


using device cuda
#model parameters: 64951


In [15]:
with open(report_file_name, 'w', buffering=1, newline='') as csvfile:
    writer = csv.writer(csvfile, delimiter=',')
    csvfile.write(f'#{sys.argv}\n')
    writer.writerow(['Epoch', 'Train-P-Loss', 'Train-R-Loss', 'Train-C-Loss', 'Train-P-Loss-Weight', 'Train-C-Loss-Weight', 'Train-P-Acc', 'Train-C-Acc', 'Train-C-Sec', 'Test-P-Acc', 'Test-C-Acc', 'Test-C-Sec', 'Train-Time', 'Test-Time'])

    for epoch in range(0, epochs + 1):
        start = time.time()

        if epoch > 0:
            with_dl = True # (epoch > args.delay) and (not is_baseline)
            train_info = train(model, device, train_loader, optimizer, oracle, grad_norm, logic, constraint, with_dl)
            train_time = time.time() - start

            save_imgs(train_info, epoch)

            print(f'Epoch {epoch}/{epochs} \t TRAIN \t P-Acc: {train_info.pred_acc:.2f} \t C-Acc: {train_info.constr_acc:.2f}\t C-Sec: {train_info.constr_sec:.2f}\t P-Loss: {train_info.pred_loss:.2f}\t R-Loss: {train_info.random_loss:.2f}\t DL-Loss: {train_info.constr_loss:.2f}\t Time (Train) [s]: {train_time:.1f}')
        else:
            train_info = EpochInfoTrain(0., 0., 0., 0., 0., 0., 1., 1., None, None, None)
            train_time = 0.

        test_info = test(model, device, test_loader, oracle_test, logic, constraint)
        test_time = time.time() - start - train_time

        save_imgs(test_info, epoch)

        writer.writerow([epoch, \
                            train_info.pred_loss, train_info.random_loss, train_info.constr_loss, train_info.pred_loss_weight, train_info.constr_loss_weight, train_info.pred_acc, train_info.constr_acc, train_info.constr_sec, \
                            test_info.pred_acc, test_info.constr_acc, test_info.constr_sec, \
                            train_time, test_time] \
                        + ([v.item() for v in test_info.vacuously_true] if with_extra_info else []))

        if with_extra_info:
            print(f'impl vacuously true=[{" ".join([f"{x:.2f}" for x in test_info.vacuously_true])}]')

        print(f'Epoch {epoch}/{epochs} \t TRAIN \t P-Acc: {train_info.pred_acc:.2f} \t C-Acc: {train_info.constr_acc:.2f}\t C-Sec: {train_info.constr_sec:.2f}\t P-Loss: {train_info.pred_loss:.2f}\t R-Loss: {train_info.random_loss:.2f}\t DL-Loss: {train_info.constr_loss:.2f}\t Time (Train) [s]: {train_time:.1f}')
        print(f'===')

Epoch 0/30 	 TRAIN 	 P-Acc: 0.00 	 C-Acc: 0.00	 C-Sec: 0.00	 P-Loss: 0.00	 R-Loss: 0.00	 DL-Loss: 0.00	 Time (Train) [s]: 0.0
===
GradNorm weights=tensor([0.9518, 1.0482])
Epoch 1/30 	 TRAIN 	 P-Acc: 0.33 	 C-Acc: 0.06	 C-Sec: 0.05	 P-Loss: 2.45	 R-Loss: 0.46	 DL-Loss: 0.61	 Time (Train) [s]: 295.4
Epoch 1/30 	 TRAIN 	 P-Acc: 0.33 	 C-Acc: 0.06	 C-Sec: 0.05	 P-Loss: 2.45	 R-Loss: 0.46	 DL-Loss: 0.61	 Time (Train) [s]: 295.4
===
GradNorm weights=tensor([0.9510, 1.0490])
Epoch 2/30 	 TRAIN 	 P-Acc: 0.77 	 C-Acc: 0.50	 C-Sec: 0.41	 P-Loss: 0.78	 R-Loss: 0.11	 DL-Loss: 0.31	 Time (Train) [s]: 366.3
Epoch 2/30 	 TRAIN 	 P-Acc: 0.77 	 C-Acc: 0.50	 C-Sec: 0.41	 P-Loss: 0.78	 R-Loss: 0.11	 DL-Loss: 0.31	 Time (Train) [s]: 366.3
===
GradNorm weights=tensor([0.9635, 1.0365])
Epoch 3/30 	 TRAIN 	 P-Acc: 0.86 	 C-Acc: 0.71	 C-Sec: 0.59	 P-Loss: 0.50	 R-Loss: 0.06	 DL-Loss: 0.18	 Time (Train) [s]: 311.9
Epoch 3/30 	 TRAIN 	 P-Acc: 0.86 	 C-Acc: 0.71	 C-Sec: 0.59	 P-Loss: 0.50	 R-Loss: 0.06	 DL-Loss

KeyboardInterrupt: 