In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset
from torchvision.models import wide_resnet50_2

In [None]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so that the DataLoader shuffle below and the later network
# initialisations are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Define transforms: resize to 224x224 and normalise with the ImageNet mean/std used by
# torchvision's pretrained backbones.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
valid_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Define known (in-distribution) and outlier classes
known_classes = [0, 1, 2, 3, 4]
outlier_classes = [5, 6, 7, 8, 9]

# The training split keeps only known classes. The validation split deliberately keeps
# ALL classes so that outlier-class samples are available for outlier detection.
known_set = set(known_classes)
train_indices = [i for i, label in enumerate(train_dataset.targets) if label in known_set]
valid_indices = list(range(len(valid_dataset)))

train_dataset = Subset(train_dataset, train_indices)
valid_dataset = Subset(valid_dataset, valid_indices)

# DataLoaders. Features are extracted from the shuffled training loader exactly once, so
# this single shuffle fixes the sample order later used to train the classifiers.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=8)


Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Build an inference-only feature extractor: an ImageNet-pretrained WideResNet-50-2 with its
# last child module (the classifier head) removed.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour of
# `weights=...`; kept unchanged here for compatibility with the installed version.
backbone = wide_resnet50_2(pretrained=True)
model = nn.Sequential(*list(backbone.children())[:-1])
del backbone  # only the retained children are needed from here on
model.to(device)
model.eval()

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=F

In [None]:
# Function to extract features
def extract_features(loader, net=None, dev=None):
    """Run every batch from ``loader`` through a feature extractor and collect the results.

    Args:
        loader: iterable of ``(data, target)`` tensor pairs, e.g. a ``DataLoader``.
        net: module producing per-sample features; defaults to the notebook-level
            ``model`` (the truncated WideResNet).
        dev: device to run ``net`` on; defaults to the notebook-level ``device``.

    Returns:
        ``(features, labels)`` NumPy arrays. ``features`` has one row per sample, with
        all non-batch dimensions of ``net``'s output flattened; ``labels`` holds the
        matching targets.
    """
    if net is None:
        net = model
    if dev is None:
        dev = device
    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for data, target in loader:
            # flatten(1) keeps the batch axis even for a size-1 batch; a bare .squeeze()
            # would drop it too and make the concatenation below fail.
            feature = net(data.to(dev)).flatten(1)
            feature_batches.append(feature.cpu().numpy())
            label_batches.append(target.numpy())
    features = np.concatenate(feature_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)
    return features, labels

# Backbone features for both splits. The validation split still contains every class.
train_features, train_labels = extract_features(loader=train_loader)
valid_features, valid_labels = extract_features(loader=valid_loader)

# Boolean mask over the validation samples: True where the label is a known class.
mask = np.isin(valid_labels, np.asarray(known_classes))


In [None]:
# Define the 2-layer neural network
class Net(nn.Module):
    """Two-layer MLP (Linear -> Sigmoid -> Linear) that returns raw class logits.

    All non-batch dimensions of the input are flattened first, so a batch shaped
    ``(N, C, 1, 1)`` is handled the same as ``(N, C)``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names fc1/fc2/sigmoid are kept: they are the state_dict keys and the
        # names reported by named_modules().
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        flat = x.view(x.size(0), -1)
        hidden = self.sigmoid(self.fc1(flat))
        return self.fc2(hidden)

In [None]:
def hook(module, input, output):
    """Forward hook that stores the layer's first positional input on the module.

    After each forward pass ``module.input_tensor`` refers to that input tensor. It is
    not detached, so it stays part of the autograd graph when gradients are enabled.
    Only a reference is kept; nothing is cloned, so no activation is copied per pass.
    """
    module.input_tensor = input[0]


def setatr(model):
    """Register ``hook`` on every ``torch.nn.Linear`` submodule of ``model``.

    Returns:
        List of the hook handles, so callers can later ``handle.remove()`` them.
        Ignoring the return value is fine.
    """
    handles = []
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            handles.append(module.register_forward_hook(hook))
    return handles

In [None]:
# Function to compute projection matrix
def compute_proj(W):
    """Return the orthogonal projector onto the row space of ``W``.

    For ``W`` of shape ``(out_features, in_features)`` the result is the
    ``(in_features, in_features)`` matrix ``P = W^+ W``, where ``W^+`` is the
    Moore-Penrose pseudo-inverse. When ``W`` has full row rank this equals
    ``W^T (W W^T)^{-1} W``. Unlike the explicit inverse, it stays well-defined when
    ``W W^T`` is singular, e.g. ``out_features > in_features`` or dependent rows.
    The result is differentiable w.r.t. ``W``.
    """
    return torch.linalg.pinv(W) @ W

In [None]:
# Evaluate model
def evaluate_model(model, features, labels, device, eval_batch_size=None):
    """Compute the top-1 accuracy (in percent) of ``model`` on pre-extracted features.

    Args:
        model: classifier mapping a float batch to per-class scores of shape ``(N, C)``.
        features: NumPy array of features; the first axis is the sample axis.
        labels: integer class labels aligned with ``features``.
        device: device the model lives on; each batch is moved there.
        eval_batch_size: mini-batch size; defaults to the notebook-level ``batch_size``.

    Returns:
        Accuracy as a float in ``[0, 100]``.

    Raises:
        ValueError: if ``features`` is empty or its length differs from ``labels``.
    """
    if eval_batch_size is None:
        eval_batch_size = batch_size
    total_samples = len(features)
    if total_samples == 0:
        raise ValueError("evaluate_model: no samples to evaluate")
    if len(labels) != total_samples:
        raise ValueError("evaluate_model: features and labels have different lengths")

    model.eval()
    corrects = 0
    with torch.no_grad():
        for start in range(0, total_samples, eval_batch_size):
            stop = start + eval_batch_size
            inputs = torch.from_numpy(features[start:stop]).float().to(device)
            targets = torch.as_tensor(labels[start:stop], device=device)
            preds = model(inputs).argmax(dim=1)
            corrects += (preds == targets).sum().item()
    return 100.0 * corrects / total_samples

In [None]:
# Hyper-parameters shared by the runs with and without the NuSA loss
hidden_dim = 32
output_dim = len(known_classes)      # one logit per known class
input_dim = train_features.shape[1]  # size of each extracted backbone feature vector
lr = 0.01
num_epochs = 10
L = 0.1                              # weight of the NuSA term added to the cross-entropy loss

In [None]:
# Baseline classifier trained with plain cross-entropy (no NuSA term)
net_no_nusa = Net(input_dim, hidden_dim, output_dim).to(device)
optimizer = optim.Adam(net_no_nusa.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss().to(device)
criterion

CrossEntropyLoss()

In [None]:
print("Training without NuSA loss")
n_train = len(train_features)
for epoch in range(num_epochs):
    net_no_nusa.train()
    # Mini-batches walk the extracted features in their stored order every epoch.
    for i in range(0, n_train, batch_size):
        batch = slice(i, i + batch_size)
        inputs = torch.from_numpy(train_features[batch]).float().to(device)
        labels = torch.from_numpy(train_labels[batch]).long().to(device)

        optimizer.zero_grad()
        outputs = net_no_nusa(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

Training without NuSA loss


In [None]:
# Accuracy is measured on known-class validation samples only.
valid_known_features = valid_features[mask]
valid_known_labels = valid_labels[mask]
accuracy_no_nusa = evaluate_model(net_no_nusa, valid_known_features, valid_known_labels, device)
print("Validation Accuracy without NuSA loss:", accuracy_no_nusa)

Validation Accuracy without NuSA loss: 92.3


In [None]:
# Classifier trained with cross-entropy plus the NuSA term. The forward hooks record each
# Linear layer's input so the NuSA ratio can be computed after every forward pass.
net_with_nusa = Net(input_dim, hidden_dim, output_dim)
setatr(net_with_nusa)
net_with_nusa.to(device)
optimizer = optim.Adam(net_with_nusa.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss().to(device)
criterion

CrossEntropyLoss()

In [None]:
print("Training with NuSA loss")
for epoch in range(num_epochs):
    net_with_nusa.train()
    for i in range(0, len(train_features), batch_size):
        inputs = torch.from_numpy(train_features[i:i+batch_size]).float().to(device)
        labels = torch.from_numpy(train_labels[i:i+batch_size]).long().to(device)
        optimizer.zero_grad()
        # The forward pass fires the hooks installed by `setatr`, which store each Linear
        # layer's input on `module.input_tensor`.
        outputs = net_with_nusa(inputs)

        # NuSA ratio ||P x_l|| / ||x_l|| per Linear layer for this batch, where P projects
        # onto the row space of the layer's weight. This must stay inside the autograd
        # graph: wrapping it in torch.no_grad() or rebuilding x_l via .tolist() detaches
        # it, and then `L * nusa_loss_mean` contributes no gradient at all.
        # NOTE(review): minimising this ratio pushes training activations OUT of the
        # weights' row space; confirm whether the null-space component (I - P) x_l was
        # intended instead.
        arr = []
        for module in net_with_nusa.modules():
            if isinstance(module, torch.nn.Linear):
                xl = module.input_tensor          # (batch, in_features), graph-attached
                P = compute_proj(module.weight)   # (in_features, in_features)
                projection = P.matmul(xl.t())
                nusa_loss = torch.norm(projection) / torch.norm(xl)
                arr.append(nusa_loss)

        nusa_loss_mean = torch.stack(arr).mean()
        loss = criterion(outputs, labels) + L * nusa_loss_mean
        loss.backward()
        optimizer.step()

Training with NuSA loss


In [None]:
# Accuracy is measured on known-class validation samples only.
valid_known_features = valid_features[mask]
valid_known_labels = valid_labels[mask]
accuracy_with_nusa = evaluate_model(net_with_nusa, valid_known_features, valid_known_labels, device)
print("Validation Accuracy with NuSA loss:", accuracy_with_nusa)

Validation Accuracy with NuSA loss: 91.48


In [None]:
# Outlier detection
# Score every validation sample with the mean NuSA ratio of the trained `net_with_nusa`.
# outlier_indicators: 0 = known-class sample (never thresholded),
#                     1 = outlier-class sample flagged (score > threshold),
#                     2 = outlier-class sample not flagged.
# outlier_class_labels: predicted class, or None for flagged outliers.
outlier_indicators = []
outlier_class_labels = []
threshold = 0.57
# NOTE(review): in the recorded run the outlier classes scored lower on average than the
# known classes, yet a sample is flagged only when its score *exceeds* `threshold`;
# confirm the comparison direction.

out_nusa = []
nonout_nusa = []

# Linear layers observed by the forward hooks. Their weights are fixed while scoring, so
# each projection matrix is computed once here instead of once per sample.
linear_layers = [m for m in net_with_nusa.modules() if isinstance(m, torch.nn.Linear)]

with torch.no_grad():
    projections = [compute_proj(layer.weight) for layer in linear_layers]
    for i in range(len(valid_features)):
        sample = valid_features[i]
        label = valid_labels[i]
        inputs = torch.from_numpy(sample).unsqueeze(0).float().to(device)
        outputs = net_with_nusa(inputs)  # hooks refresh layer.input_tensor
        arr = []
        for layer, P in zip(linear_layers, projections):
            xl = layer.input_tensor  # (1, in_features) for this single sample
            projection = P.matmul(xl.t())
            arr.append(torch.norm(projection) / torch.norm(xl))

        nusa_mean = torch.stack(arr).mean()
        if label in known_classes:
            nonout_nusa.append(nusa_mean)
            outlier_indicators.append(0)
            outlier_class_labels.append(torch.argmax(outputs).item())
        else:
            out_nusa.append(nusa_mean)
            if nusa_mean > threshold:
                outlier_indicators.append(1)
                outlier_class_labels.append(None)
            else:
                outlier_indicators.append(2)
                outlier_class_labels.append(torch.argmax(outputs).item())

In [None]:
# Mean NuSA score per group (known-class vs outlier-class validation samples).
for title, scores in (("NuSa for non-outliers", nonout_nusa), ("NuSa for outliers", out_nusa)):
    print(title, torch.tensor(scores).to(device).mean().item())

NuSa for non-outliers 0.5825128555297852
NuSa for outliers 0.5539392828941345


In [None]:
# First 20 per-sample scores of each group, known-class first.
for scores in (nonout_nusa, out_nusa):
    print(scores[:20])

[tensor(0.6016, device='cuda:0'), tensor(0.6041, device='cuda:0'), tensor(0.6147, device='cuda:0'), tensor(0.5911, device='cuda:0'), tensor(0.6003, device='cuda:0'), tensor(0.5184, device='cuda:0'), tensor(0.5833, device='cuda:0'), tensor(0.5360, device='cuda:0'), tensor(0.5892, device='cuda:0'), tensor(0.6015, device='cuda:0'), tensor(0.5598, device='cuda:0'), tensor(0.5807, device='cuda:0'), tensor(0.5344, device='cuda:0'), tensor(0.6067, device='cuda:0'), tensor(0.5807, device='cuda:0'), tensor(0.5067, device='cuda:0'), tensor(0.5739, device='cuda:0'), tensor(0.5760, device='cuda:0'), tensor(0.5528, device='cuda:0'), tensor(0.6013, device='cuda:0')]
[tensor(0.5896, device='cuda:0'), tensor(0.5565, device='cuda:0'), tensor(0.5057, device='cuda:0'), tensor(0.5567, device='cuda:0'), tensor(0.5738, device='cuda:0'), tensor(0.5383, device='cuda:0'), tensor(0.5678, device='cuda:0'), tensor(0.5860, device='cuda:0'), tensor(0.5562, device='cuda:0'), tensor(0.6228, device='cuda:0'), tensor(0