In [1]:
import numpy as np
import torchvision
import torch
import torch.nn as nn
from tqdm import tqdm

from torchmetrics import Accuracy
from torch.utils.data import DataLoader, Dataset
from cifar10_models.resnet import resnet18
from sklearn.model_selection import train_test_split

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  warn(f"Failed to load image Python extension: {e}")


In [9]:
# Task 0 data: CIFAR-10-C gaussian-noise images and their labels.
cifar_gaussian_noise = np.load("CIFAR-10-C/gaussian_noise.npy")
# cross_entropy needs int64 class indices; cast explicitly instead of relying on the stored dtype.
cifar_labels = np.load("CIFAR-10-C/labels.npy").astype(np.int64)

CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2023, 0.1994, 0.2010)
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD)]
)

# Task 1 data: clean CIFAR-10 training images with reversed labels (y -> 9 - y).
cifar_f = torchvision.datasets.CIFAR10(root='./data', train=True,
                                       download=True, transform=transform)
cifar_f.targets = [9 - target for target in cifar_f.targets]

# ToTensor + Normalize on each noisy image (assumed uint8 HWC) -> normalised float CHW tensor.
cifar_gn = [transform(image) for image in cifar_gaussian_noise]
assert len(cifar_gn) == len(cifar_labels)

# cifar_f.data is the raw uint8 channels-last array and bypasses cifar_f's transform, so
# task-1 images would otherwise reach the pretrained backbone unnormalised. Apply the same
# scaling by hand (in place, to limit peak memory) but keep the channels-last layout,
# because the training loop permutes task-1 batches to NCHW itself.
cifar_f_images = cifar_f.data.astype(np.float32)
cifar_f_images /= 255.0
cifar_f_images -= np.asarray(CIFAR_MEAN, dtype=np.float32)
cifar_f_images /= np.asarray(CIFAR_STD, dtype=np.float32)
# MultiTaskDataset pairs the two tasks purely by index, so both sources must be equally long.
assert len(cifar_f_images) == len(cifar_gn)

cifar_gnx_train, cifar_gnx_test, cifar_gny_train, cifar_gny_test = train_test_split(
    cifar_gn, cifar_labels, test_size=0.30, random_state=42
)

cifar_fx_train, cifar_fx_test, cifar_fy_train, cifar_fy_test = train_test_split(
    cifar_f_images, cifar_f.targets, test_size=0.30, random_state=42
)

AttributeError: 'CIFAR10' object has no attribute 'to'

In [7]:
class MultiTaskDataset(Dataset):
    """Index-aligned pairing of two labelled datasets for two-task training.

    Item ``idx`` is ``(t0_x[idx], t1_x[idx], t0_y[idx], t1_y[idx])``.

    Args:
        t0_x, t0_y: inputs and labels for task 0.
        t1_x, t1_y: inputs and labels for task 1.

    Raises:
        ValueError: if the four sequences do not all have the same length. Items
            are paired purely by position, so a mismatch would otherwise silently
            drop samples or fail with IndexError partway through an epoch.
    """

    def __init__(self, t0_x, t1_x, t0_y, t1_y):
        if len({len(t0_x), len(t1_x), len(t0_y), len(t1_y)}) != 1:
            raise ValueError(
                "MultiTaskDataset inputs must have equal lengths, got "
                f"t0_x={len(t0_x)}, t1_x={len(t1_x)}, t0_y={len(t0_y)}, t1_y={len(t1_y)}"
            )
        self.t0_x = t0_x
        self.t1_x = t1_x
        self.t0_y = t0_y
        self.t1_y = t1_y

    def __getitem__(self, idx):
        return self.t0_x[idx], self.t1_x[idx], self.t0_y[idx], self.t1_y[idx]

    def __len__(self):
        return len(self.t0_x)

In [8]:
# Pair each gaussian-noise sample with a label-reversed clean sample, index by index.
X_train = MultiTaskDataset(
    t0_x=cifar_gnx_train,
    t1_x=cifar_fx_train,
    t0_y=cifar_gny_train,
    t1_y=cifar_fy_train,
)
X_test = MultiTaskDataset(
    t0_x=cifar_gnx_test,
    t1_x=cifar_fx_test,
    t0_y=cifar_gny_test,
    t1_y=cifar_fy_test,
)

# Only the training split is shuffled; evaluation order stays fixed.
trainloader = DataLoader(X_train, shuffle=True, batch_size=256)
testloader = DataLoader(X_test, shuffle=False, batch_size=256)

AttributeError: 'MultiTaskDataset' object has no attribute 'to'

In [183]:
class MTL_resnet(nn.Module):
    """Two-task classifier: one ResNet-18 trunk shared by two 10-way linear heads.

    The trunk is ``resnet18(pretrained=True)`` with its last child module removed.
    Both heads read 512 features, so the trunk output is assumed to be
    (B, 512, 1, 1) -- TODO confirm against ``cifar10_models.resnet``.

    Args:
        shared: currently unused; kept so existing ``MTL_resnet(shared=...)`` calls work.
    """

    def __init__(self, shared=True):
        super().__init__()

        pretrained = resnet18(pretrained=True)
        trunk_layers = list(pretrained.children())[:-1]  # everything except the last child
        self.resnet_backbone = nn.Sequential(*trunk_layers)

        self.t0_head = nn.Linear(512, 10)
        self.t1_head = nn.Linear(512, 10)

    def _embed(self, images):
        """Run the shared trunk and collapse its 4-D output to (batch, channels)."""
        features = self.resnet_backbone(images)
        batch, channels, _, _ = features.shape
        return features.view((batch, channels))

    def forward(self, t0_x, t1_x):
        """Return ``(t0_logits, t1_logits)``: each input through the shared trunk, then its own head."""
        t0_logits = self.t0_head(self._embed(t0_x))
        t1_logits = self.t1_head(self._embed(t1_x))
        return t0_logits, t1_logits

In [11]:
# NOTE(review): torchmetrics >= 0.11 requires Accuracy(task="multiclass", num_classes=10).
accuracy = Accuracy()


def _param_groups(net):
    """Split ``net``'s parameters into the two Adam parameter groups used for training.

    Group 0 holds every parameter whose name contains both "backbone" and "weight";
    group 1 holds everything else (including both task heads). The RGN phase
    rescales each group's learning rate separately.
    """
    backbone_weights, rest = [], []
    for name, param in net.named_parameters():
        if "backbone" in name and "weight" in name:
            backbone_weights.append(param)
        else:
            rest.append(param)
    return [{"params": backbone_weights}, {"params": rest}]


def _unpack_batch(batch):
    """Move one MultiTaskDataset batch to DEVICE in the form the model expects.

    Task-1 images are expected channels-last (N, H, W, C) and are permuted to
    (N, C, H, W) float. Labels are cast to int64, as ``cross_entropy`` requires.
    """
    t0_x, t1_x, t0_labels, t1_labels = batch
    t0_x = t0_x.to(DEVICE)
    t1_x = torch.permute(t1_x.to(DEVICE), (0, 3, 1, 2)).float()
    return t0_x, t1_x, t0_labels.to(DEVICE).long(), t1_labels.to(DEVICE).long()


def _batch_accuracy(logits, labels):
    """Top-1 accuracy of one batch as a Python float, via the module-level ``accuracy`` metric."""
    return float(accuracy(logits.detach().cpu(), labels.detach().cpu()))


def _relative_grad_norm(loss, param):
    """Return ||d(loss)/d(param)|| / ||param||, or 1.0 if ``param`` does not influence ``loss``.

    The graph of ``loss`` is retained so the caller can still call ``backward`` on it.
    """
    (grad,) = torch.autograd.grad(loss, param, allow_unused=True, retain_graph=True)
    if grad is None:
        return 1.0
    return (torch.linalg.norm(grad) / torch.linalg.norm(param.detach())).item()


def _rgn_step(t0_model, t1_model, t0_optimizer, t1_optimizer, tensors, lr):
    """One task-specific update of both task models with RGN-weighted learning rates.

    For each parameter-group pair, the relative gradient norms of the group's first
    parameter under each task's loss are softmax-normalised across the two tasks.
    Each task model's group learning rate is then set to ``weight * lr``.

    Returns:
        ``(t0_logits, t1_logits, t0_loss, t1_loss)`` for metric logging.
    """
    t0_x, t1_x, t0_labels, t1_labels = tensors
    t0_logits, _ = t0_model(t0_x, t1_x)
    _, t1_logits = t1_model(t0_x, t1_x)
    t0_loss = nn.functional.cross_entropy(t0_logits, t0_labels)
    t1_loss = nn.functional.cross_entropy(t1_logits, t1_labels)

    for t0_group, t1_group in zip(t0_optimizer.param_groups, t1_optimizer.param_groups):
        rgns = torch.tensor([
            _relative_grad_norm(t0_loss, t0_group["params"][0]),
            _relative_grad_norm(t1_loss, t1_group["params"][0]),
        ])
        t0_weight, t1_weight = nn.functional.softmax(rgns, dim=0).tolist()
        # Scale from the base lr every step so the rate does not shrink geometrically.
        t0_group["lr"] = t0_weight * lr
        t1_group["lr"] = t1_weight * lr

    for task_optimizer, task_loss in ((t0_optimizer, t0_loss), (t1_optimizer, t1_loss)):
        task_optimizer.zero_grad()
        task_loss.backward()
        task_optimizer.step()
    return t0_logits, t1_logits, t0_loss, t1_loss


def _evaluate(t0_net, t1_net, loader):
    """Sample-weighted mean validation metrics for t0_net's t0 head and t1_net's t1 head.

    Returns:
        dict with keys "loss/total", "loss/t0", "loss/t1", "acc/t0", "acc/t1".

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    sums = {"loss/t0": 0.0, "loss/t1": 0.0, "acc/t0": 0.0, "acc/t1": 0.0}
    count = 0
    t0_net.eval()
    t1_net.eval()
    with torch.no_grad():
        for batch in loader:
            t0_x, t1_x, t0_labels, t1_labels = _unpack_batch(batch)
            t0_logits, t1_logits = t0_net(t0_x, t1_x)
            if t1_net is not t0_net:
                _, t1_logits = t1_net(t0_x, t1_x)
            n = len(t0_x)
            count += n
            sums["loss/t0"] += nn.functional.cross_entropy(t0_logits, t0_labels).item() * n
            sums["loss/t1"] += nn.functional.cross_entropy(t1_logits, t1_labels).item() * n
            sums["acc/t0"] += _batch_accuracy(t0_logits, t0_labels) * n
            sums["acc/t1"] += _batch_accuracy(t1_logits, t1_labels) * n
    if count == 0:
        raise ValueError("test_loader yielded no samples")
    means = {key: total / count for key, total in sums.items()}
    means["loss/total"] = (means["loss/t0"] + means["loss/t1"]) / 2
    return means


def train_mtl_rgn(
    *,
    train_loader,
    test_loader,
    lr=1e-3,
    epochs=5,
    rgn_start=None,
):
    """Train a two-head ``MTL_resnet``, optionally finishing with RGN task-specific updates.

    Shared phase: a single model is trained on the mean of both tasks'
    cross-entropy losses.

    RGN phase (only when ``rgn_start`` is set): from global training step
    ``rgn_start`` onward, the shared model's weights and Adam state are copied
    into two task-specific models. These are then updated with learning rates
    reweighted by relative gradient norms (see ``_rgn_step``).

    Args:
        train_loader, test_loader: DataLoaders yielding MultiTaskDataset batches.
        lr: base Adam learning rate.
        epochs: number of passes over ``train_loader``.
        rgn_start: global step (counted across epochs from 0) at which the RGN
            phase begins; ``None`` (default) disables it.

    Returns:
        ``(train_metrics, val_metrics)``: dicts mapping "loss/total", "loss/t0",
        "loss/t1", "acc/t0", "acc/t1" to lists of floats. There is one entry per
        training batch and one per epoch, respectively.
    """
    import copy

    model = MTL_resnet().to(DEVICE)
    t0_model = MTL_resnet().to(DEVICE)
    t1_model = MTL_resnet().to(DEVICE)

    optimizer = torch.optim.Adam(_param_groups(model), lr=lr)
    # Each task optimizer owns its own model's parameters, never the shared model's.
    t0_optimizer = torch.optim.Adam(_param_groups(t0_model), lr=lr)
    t1_optimizer = torch.optim.Adam(_param_groups(t1_model), lr=lr)

    metric_keys = ("loss/total", "loss/t0", "loss/t1", "acc/t0", "acc/t1")
    train_metrics = {key: [] for key in metric_keys}
    val_metrics = {key: [] for key in metric_keys}

    step = 0
    rgn_active = False
    for _ in range(epochs):
        # Validation below switches models to eval mode; restore train mode every epoch.
        model.train()
        t0_model.train()
        t1_model.train()
        for batch in tqdm(train_loader, total=len(train_loader)):
            tensors = _unpack_batch(batch)
            t0_x, t1_x, t0_labels, t1_labels = tensors

            if rgn_start is not None and step >= rgn_start:
                if not rgn_active:
                    # Branch both task models off the shared model. Optimizer state is
                    # deep-copied so the task optimizers never share Adam moment buffers
                    # with each other or with the shared optimizer.
                    for task_model, task_optimizer in ((t0_model, t0_optimizer), (t1_model, t1_optimizer)):
                        task_model.load_state_dict(model.state_dict())
                        task_optimizer.load_state_dict(copy.deepcopy(optimizer.state_dict()))
                    rgn_active = True
                t0_logits, t1_logits, t0_loss, t1_loss = _rgn_step(
                    t0_model, t1_model, t0_optimizer, t1_optimizer, tensors, lr
                )
            else:
                t0_logits, t1_logits = model(t0_x, t1_x)
                t0_loss = nn.functional.cross_entropy(t0_logits, t0_labels)
                t1_loss = nn.functional.cross_entropy(t1_logits, t1_labels)
                optimizer.zero_grad()
                # Back-propagate both tasks; t0_loss alone would leave the t1 head untrained.
                ((t0_loss + t1_loss) / 2).backward()
                optimizer.step()

            train_metrics["loss/total"].append((t0_loss.item() + t1_loss.item()) / 2)
            train_metrics["loss/t0"].append(t0_loss.item())
            train_metrics["loss/t1"].append(t1_loss.item())
            # Reuse this step's logits; a second train-mode forward would update BN stats again.
            train_metrics["acc/t0"].append(_batch_accuracy(t0_logits, t0_labels))
            train_metrics["acc/t1"].append(_batch_accuracy(t1_logits, t1_labels))
            step += 1

        # Evaluate whichever model(s) were actually being trained.
        if rgn_active:
            epoch_val = _evaluate(t0_model, t1_model, test_loader)
        else:
            epoch_val = _evaluate(model, model, test_loader)
        for key in metric_keys:
            val_metrics[key].append(epoch_val[key])

    return train_metrics, val_metrics

In [12]:
# Train the two-task model; only the per-epoch validation metrics are kept.
_, shared_metrics = train_mtl_rgn(
    lr=1e-3,
    epochs=5,
    train_loader=trainloader,
    test_loader=testloader,
)

NameError: name 'trainloader' is not defined

In [13]:
from robustbench.utils import load_model
from robustbench.data import load_cifar10c

In [14]:
corruptions = ["gaussian_noise"]  # CIFAR-10-C corruption types to load below

In [15]:
# 1000 CIFAR-10-C examples at corruption severity 5 for each listed corruption.
x_test, y_test = load_cifar10c(
    n_examples=1000,
    severity=5,
    corruptions=corruptions,
)

Starting download from https://zenodo.org/api/files/a35f793a-6997-4603-a92f-926a6bc0fa60/CIFAR-10-C.tar


44533it [23:17, 31.87it/s]                                                      


Download finished, extracting...
Downloaded and extracted.


In [None]:
# `clean_accuracy` was never imported and `model_name` never defined, so this cell raised
# NameError on a fresh kernel. Restore the loop over models from the RobustBench example.
from robustbench.utils import clean_accuracy

# Score each RobustBench CIFAR-10 Linf model on the corrupted test subset (x_test, y_test).
# Each load_model call may download that model's weights.
for model_name in ['Standard', 'Engstrom2019Robustness', 'Rice2020Overfitting',
                   'Carmon2019Unlabeled']:
    model = load_model(model_name, dataset='cifar10', threat_model='Linf')
    acc = clean_accuracy(model, x_test, y_test)
    print(f'Model: {model_name}, CIFAR-10-C accuracy: {acc:.1%}')