This is an exemplar demonstration using CIFAR-10

In [1]:
import os

import matplotlib.pyplot as plt
import numpy
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

from CIFAR_dataset import *
from models import *

# Seed the global NumPy and PyTorch generators once for the whole notebook.
# The training cells re-seed both generators again for every configuration.
np.random.seed(42)
torch.manual_seed(42)

In [2]:
# Paths of datasets
# Please modify accordingly
data_root = "../data"

batch_size = 1024
# Bias levels of the training sets: each value is passed as both `alpha` and
# `beta` to CIFAR_subset_watermark when the biased training set is built.
zeta_trains = 1-np.array([0.001,0.002,0.005,0.01,0.02,0.05,0.1,0.2])
latent_dim = 64
num_epochs = 200

# Use the first GPU when one is available; otherwise fall back to the CPU so
# that the notebook still runs (slowly) instead of failing on the first `.to(device)`.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## Please select the task: "origin", "regularization", or "groupDRO"
# task = "origin"
# task = "regularization"
task = "groupDRO"

In [3]:
# Load the raw CIFAR-10 data (train and test splits), converting images to tensors
trainset = torchvision.datasets.CIFAR10(
    data_root,
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
)
testset = torchvision.datasets.CIFAR10(
    data_root,
    train=False,
    download=True,
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
)

# Select the two binary classes from 0~9
dig0, dig1 = 0, 1

# Boolean masks marking the samples whose class is dig0 or dig1 ...
where_train, where_test = (
    (labels == dig0) | (labels == dig1)
    for labels in (torch.tensor(trainset.targets), torch.tensor(testset.targets))
)
# ... and the positions of those samples inside each split.
index_train, index_test = (
    torch.where(mask)[0] for mask in (where_train, where_test)
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# The regularization term
def reg(emb, image, label, spurious):
    """Ratio of spurious-feature separation to label separation in embedding space.

    The batch is split into the four (spurious, label) groups and the mean
    embedding of each group is computed. From these group means:

    * ``mu_spur`` is half the difference between the spurious==1 and
      spurious==0 centres (each centre averages its two label groups);
    * ``mu_info`` is half the difference between the label==1 and label==0
      centres (each centre averages its two spurious groups).

    Args:
        emb: embedding tensor, indexed along dim 0 by boolean masks built from
            ``label`` and ``spurious`` (assumed shape (batch, latent_dim) --
            TODO confirm against the model).
        image: unused; kept so existing callers do not need to change.
        label: per-sample binary class labels (compared against 0 and 1).
        spurious: per-sample binary spurious attribute (compared against 0 and 1).

    Returns:
        Scalar tensor ``||mu_spur|| / ||mu_info||``. If any of the four groups
        is empty its mean is NaN, and so is the result.
    """
    def group_mean(spur_value, label_value):
        # Mean embedding of the samples with the given spurious/label values.
        return emb[(spurious == spur_value) * (label == label_value)].mean(0)

    mean_s1_y1 = group_mean(1, 1)
    mean_s1_y0 = group_mean(1, 0)
    mean_s0_y1 = group_mean(0, 1)
    mean_s0_y0 = group_mean(0, 0)

    # Separation along the spurious attribute (label-balanced centres).
    mu_spur_1 = (mean_s1_y1 + mean_s1_y0)/2
    mu_spur_0 = (mean_s0_y1 + mean_s0_y0)/2
    mu_spur = (mu_spur_1 - mu_spur_0)/2

    # Separation along the true label (spurious-balanced centres).
    mu_info_1 = (mean_s1_y1 + mean_s0_y1)/2
    mu_info_0 = (mean_s1_y0 + mean_s0_y0)/2
    mu_info = (mu_info_1 - mu_info_0)/2

    # The previous version also built a covariance-like matrix (`sig_spur`)
    # that was never used; it cost two d x d matrix products per batch and
    # read the global `trainset_biased`, so it has been removed.
    loss = mu_spur.norm()/mu_info.norm()
    return loss

In [5]:
## Train the models with no regularization or group DRO
if task == "origin":
    # Create the output directory up front so that a full training run is not
    # lost when the results are written at the very end.
    os.makedirs('results/cifar_watermark', exist_ok=True)

    for seed in range(10):
        print(f'==========================Seed: {seed}=============================')
        for model_name in ['resnet', 'vgg', 'alexnet']:
            print(f'========================Model: {model_name}=========================')
            for zeta_train in zeta_trains:
                print(f'========================Zeta: {zeta_train}===========================')
                # Re-seed for every (seed, model, zeta) configuration.
                torch.manual_seed(seed)
                np.random.seed(seed)

                # NOTE(review): trainset_nonbiased is not used below. It is kept because
                # constructing it may consume random state that the following
                # datasets depend on -- confirm before removing.
                trainset_nonbiased = CIFAR_subset_watermark(trainset, index_train, alpha = .5, beta = .5)
                testset_nonbiased = CIFAR_subset_watermark(testset, index_test, alpha = .5, beta = .5)
                trainset_biased = CIFAR_subset_watermark(trainset, index_train, alpha = zeta_train, beta = zeta_train)

                trainloader = DataLoader(trainset_biased, batch_size=batch_size, num_workers = 8, shuffle=True)
                Loss = []
                if model_name == 'resnet':
                    model = Classifier_resnet(latent_dim).to(device)
                elif model_name == 'vgg':
                    model = Classifier_vgg(latent_dim).to(device)
                elif model_name == 'alexnet':
                    model = Classifier_alexnet(latent_dim).to(device)
                elif model_name == 'efficientnet':
                    model = Classifier_efficientnet(latent_dim).to(device)
                elif model_name == 'mobilenet':
                    model = Classifier_mobilenet(latent_dim).to(device)

                optimizer = optim.SGD(model.parameters(), lr = 1e-3, momentum = 0.9, weight_decay = 5e-4)
                # Per-epoch histories that are written to disk after training.
                Acc_train = []
                Acc_test = []
                Cls_loss = []
                Reg_loss = []
                criterion = nn.BCELoss()
                model.train()
                for epoch in range(num_epochs):
                    cls_loss = 0.0
                    reg_loss = 0.0
                    acc_train = 0
                    model.train()
                    for data in trainloader:
                        images, label = data
                        # Channel axis is moved from last to second position before the model call.
                        images = images.permute(0,3,1,2).to(device)
                        # Column 0 of the target is the class label; column 1 is passed
                        # to reg() as the spurious attribute.
                        label = label[:,0].unsqueeze(1).to(device)
                        optimizer.zero_grad()
                        emb, pred = model(images)

                        # The regularization term is only monitored here: it is computed
                        # without gradients and never added to the training loss. It is
                        # skipped when either minority group is absent from the batch.
                        with torch.no_grad():
                            if ((data[1][:,1] == 1)*(data[1][:,0] == 0)).sum().item()*((data[1][:,1] == 0)*(data[1][:,0] == 1)).sum().item() == 0:
                                loss_reg = 0
                            else:
                                loss_reg = reg(emb,images, data[1][:,0], data[1][:,1])

                        loss_cls = criterion(pred, label)
                        loss = loss_cls
                        loss.backward()
                        optimizer.step()
                        cls_loss += loss_cls.item()
                        if loss_reg > 0:
                            reg_loss += loss_reg.item()
                        acc_train += ((pred > 0.5)==label).float().sum().item()

                    acc_train/=len(trainset_biased)

                    # Evaluate on the whole unbiased test set in a single forward pass.
                    model.eval()
                    with torch.no_grad():
                        data = testset_nonbiased.data.permute(0,3,1,2).clone().detach().to(device)
                        label = testset_nonbiased.targets.view(-1,1).clone().detach().to(device)
                        _, pred = model(data)
                        acc_test = ((pred > 0.5)==label).float().mean().item()

                    # If the mean epoch loss exceeds 10, restart the optimizer with a 10x smaller
                    # learning rate (a new SGD instance, so momentum buffers start from scratch).
                    if cls_loss/len(trainloader)>10:
                        optimizer = optim.SGD(model.parameters(), lr = 1e-4, momentum = 0.9, weight_decay = 5e-4)

                    # zeta_train is a float below 1, so it must not be formatted with %d
                    # (which truncated every value to 0 in the log).
                    print('S/M/Z: [%d/%s/%.3f], Epoch [%d/%d], Loss: [%.4f,%.4f], Acc: [%.4f/%.4f]' % (
                        seed, model_name, zeta_train, epoch+1, num_epochs, cls_loss/len(trainloader),
                        reg_loss/len(trainloader), acc_train, acc_test))

                    Acc_test.append(acc_test)
                    Acc_train.append(acc_train)
                    Cls_loss.append(cls_loss/len(trainloader))
                    Reg_loss.append(reg_loss/len(trainloader))


                # NOTE(review): the file name patterns below are not consistent with each
                # other (missing '_' separators, a stray 'f'); they are left unchanged so
                # that anything reading these results keeps working.
                torch.save(model.state_dict(), f'results/cifar_watermark/noReg_{model_name}_{zeta_train}_f{num_epochs}_seed{seed}')
                np.savetxt(f'results/cifar_watermark/trainAcc_zeta_{model_name}{zeta_train}_noReg_{num_epochs}_seed{seed}.txt', Acc_train)
                np.savetxt(f'results/cifar_watermark/testAcc_zeta_{model_name}_{zeta_train}_noReg_{num_epochs}_seed{seed}.txt', Acc_test)
                np.savetxt(f'results/cifar_watermark/clsLoss_zeta_{model_name}_{zeta_train}_noReg_{num_epochs}_seed{seed}.txt', Cls_loss)
                np.savetxt(f'results/cifar_watermark/regLoss_zeta{model_name}_{zeta_train}_noReg_{num_epochs}_seed{seed}.txt', Reg_loss)


In [6]:
# Train the models with the regularization term added to the classification loss
if task == "regularization":
    # Create the output directory up front so that a full training run is not
    # lost when the results are written at the very end.
    os.makedirs('results/cifar_watermark', exist_ok=True)

    for seed in range(10):
        print(f'==========================Seed: {seed}=============================')
        for model_name in ['resnet', 'vgg', 'alexnet']:
            print(f'========================Model: {model_name}=========================')
            for zeta_train in zeta_trains:
                print(f'========================Zeta: {zeta_train}===========================')
                # Re-seed for every (seed, model, zeta) configuration.
                torch.manual_seed(seed)
                np.random.seed(seed)

                testset_nonbiased = CIFAR_subset_watermark(testset, index_test, alpha = .5, beta = .5)
                trainset_biased = CIFAR_subset_watermark(trainset, index_train, alpha = zeta_train, beta = zeta_train)

                trainloader = DataLoader(trainset_biased, batch_size=batch_size, num_workers = 8, shuffle=True)
                Loss = []
                if model_name == 'resnet':
                    model = Classifier_resnet(latent_dim).to(device)
                elif model_name == 'vgg':
                    model = Classifier_vgg(latent_dim).to(device)
                elif model_name == 'alexnet':
                    model = Classifier_alexnet(latent_dim).to(device)

                optimizer = optim.SGD(model.parameters(), lr = 5e-4, momentum = 0.9, weight_decay = 1e-4)

                # Per-epoch histories that are written to disk after training.
                Acc_train = []
                Acc_test = []
                Cls_loss = []
                Reg_loss = []
                criterion = nn.BCELoss()
                model.train()
                for epoch in range(num_epochs):
                    cls_loss = 0.0
                    reg_loss = 0.0
                    acc_train = 0
                    model.train()
                    for data in trainloader:
                        images, label = data
                        # Channel axis is moved from last to second position before the model call.
                        images = images.permute(0,3,1,2).to(device)
                        # Column 0 of the target is the class label; column 1 is passed
                        # to reg() as the spurious attribute.
                        label = label[:,0].unsqueeze(1).to(device)
                        optimizer.zero_grad()
                        emb, pred = model(images)

                        # The regularizer is skipped (set to 0) when either minority group
                        # is absent from the batch; otherwise it is computed with gradients.
                        if ((data[1][:,1] == 1)*(data[1][:,0] == 0)).sum().item()*((data[1][:,1] == 0)*(data[1][:,0] == 1)).sum().item() == 0:
                            loss_reg = 0
                        else:
                            loss_reg = reg(emb,images, data[1][:,0], data[1][:,1])

                        loss_cls = criterion(pred, label)

                        # Unlike the "origin" task, the regularizer is part of the optimized loss.
                        loss = loss_cls + loss_reg
                        loss.backward()

                        optimizer.step()
                        cls_loss += loss_cls.item()
                        if loss_reg > 0:
                            reg_loss += loss_reg.item()
                        acc_train += ((pred > 0.5)==label).float().sum().item()

                    acc_train/=len(trainset_biased)

                    # Evaluate on the whole unbiased test set in a single forward pass.
                    model.eval()
                    with torch.no_grad():
                        data = testset_nonbiased.data.permute(0,3,1,2).clone().detach().to(device)
                        label = testset_nonbiased.targets.view(-1,1).clone().detach().to(device)
                        _, pred = model(data)
                        acc_test = ((pred > 0.5)==label).float().mean().item()

                    print('Epoch [%d/%d], Loss: [%.4f,%.4f], Acc: [%.4f/%.4f]' % (
                        epoch+1, num_epochs, cls_loss/len(trainloader),
                        reg_loss/len(trainloader), acc_train, acc_test))

                    Acc_test.append(acc_test)
                    Acc_train.append(acc_train)
                    Cls_loss.append(cls_loss/len(trainloader))
                    Reg_loss.append(reg_loss/len(trainloader))


                # NOTE(review): the file name patterns below are not consistent with each
                # other (missing '_' separators); they are left unchanged so that anything
                # reading these results keeps working.
                torch.save(model.state_dict(), f'results/cifar_watermark/Reg_{model_name}_{zeta_train}_{num_epochs}_seed{seed}')
                np.savetxt(f'results/cifar_watermark/trainAcc_zeta_{model_name}{zeta_train}_Reg_{num_epochs}_seed{seed}.txt', Acc_train)
                np.savetxt(f'results/cifar_watermark/testAcc_zeta_{model_name}_{zeta_train}_Reg_{num_epochs}_seed{seed}.txt', Acc_test)
                np.savetxt(f'results/cifar_watermark/clsLoss_zeta_{model_name}_{zeta_train}_Reg_{num_epochs}_seed{seed}.txt', Cls_loss)
                np.savetxt(f'results/cifar_watermark/regLoss_zeta{model_name}_{zeta_train}_Reg_{num_epochs}_seed{seed}.txt', Reg_loss)


In [None]:
# Group-DRO task: everything below in this cell runs only when task == "groupDRO".
if task == "groupDRO":

    # Unreduced binary cross-entropy: reduction='none' keeps one loss value per
    # element, which loss_fn then averages within each group.
    criterion = nn.BCELoss(reduction = 'none')
    def group_accuracy(model, trainSubloader):
        """Return the classification accuracy of `model` on each loader, in order.

        The model is put in eval mode and every loader in `trainSubloader` is
        iterated once without gradient tracking. A prediction counts as correct
        when `pred > 0.5` matches column 0 of the batch target.
        """
        model.eval()
        per_group_acc = []
        with torch.no_grad():
            for loader in trainSubloader:
                n_correct, n_seen = 0, 0
                for batch_images, batch_target in loader:
                    # Channel axis is moved from last to second position before the model call.
                    inputs = batch_images.permute(0, 3, 1, 2).to(device)
                    truth = batch_target[:, 0].unsqueeze(1).to(device)
                    _, pred = model(inputs)
                    hits = (pred > 0.5) == truth
                    n_correct += hits.float().sum().item()
                    n_seen += len(inputs)
                per_group_acc.append(n_correct / n_seen)
        return per_group_acc

    def total_accuracy(model, dataloader):
        """Return the overall classification accuracy of `model` on `dataloader`.

        The model is put in eval mode and the loader is iterated once. A
        prediction counts as correct when `pred > 0.5` matches column 0 of the
        batch target. The result is (number correct) / (number of samples seen).
        """
        model.eval()
        acc = 0
        ct = 0
        # Evaluation only: disable gradient tracking (as group_accuracy does) so that
        # no autograd graph is built and kept alive for each forward pass.
        with torch.no_grad():
            for data in dataloader:
                images, label = data
                # Channel axis is moved from last to second position before the model call.
                images = images.permute(0,3,1,2).to(device)
                label = label[:,0].unsqueeze(1).to(device)
                _, pred = model(images)
                acc += ((pred > 0.5)==label).float().sum().item()
                ct += len(images)
        return acc/ct


    def train(model, dataloader):
        """Run one training epoch with the group-robust loss.

        For each batch the group index of every sample is looked up in
        `group_mapping` from its (column 0, column 1) target pair, the model is
        optimized on `loss_fn(pred, label, grp)`, and the regularization term is
        evaluated without gradients for monitoring only.

        Returns:
            (cls_loss, reg_loss): sums over all batches of the robust loss and of
            the monitored regularization term (batches where the term was skipped
            or not positive add nothing).
        """
        model.train()
        total_cls, total_reg = 0, 0
        for batch_images, target in dataloader:
            # Channel axis is moved from last to second position before the model call.
            inputs = batch_images.permute(0, 3, 1, 2).to(device)
            y_true, y_spur = target[:, 0], target[:, 1]
            label = y_true.unsqueeze(1).to(device)

            group_ids = [group_mapping[(t.item(), s.item())] for t, s in zip(y_true, y_spur)]
            grp = torch.tensor(group_ids).to(device)

            emb, pred = model(inputs)
            loss_cls = loss_fn(pred, label, grp)

            # Monitoring only: skipped when either minority group is missing from the batch.
            with torch.no_grad():
                n_spur_only = ((y_spur == 1) & (y_true == 0)).sum().item()
                n_label_only = ((y_spur == 0) & (y_true == 1)).sum().item()
                if n_spur_only == 0 or n_label_only == 0:
                    loss_reg = 0
                else:
                    loss_reg = reg(emb, inputs, y_true, y_spur)

            optimizer.zero_grad()
            loss_cls.backward()
            optimizer.step()

            total_cls += loss_cls.item()
            if loss_reg > 0:
                total_reg += loss_reg.item()
        return total_cls, total_reg

    # Maps a (target column 0, target column 1) pair -- i.e. (label, spurious) as
    # used by train() -- to the group index consumed by the group-robust loss.
    group_mapping = {
        (1, 1): 0,
        (1, 0): 1,
        (0, 1): 2,
        (0, 0): 3,
    }

    def compute_grp_avg(losses, grp, n_groups=4):
        """Average `losses` within each group.

        Args:
            losses: tensor with one value per sample (any shape; it is flattened).
            grp: 1-D tensor of integer group indices, one per sample, on the same
                device as `losses`.
            n_groups: number of groups; indices are expected in [0, n_groups).
                Defaults to 4, the number of entries in `group_mapping`.

        Returns:
            (grp_loss, grp_count): per-group mean of `losses` (0 for a group with
            no samples) and the per-group sample counts, both of length `n_groups`.
        """
        # Row g of grp_map is a 0/1 indicator of the samples that belong to group g.
        group_ids = torch.arange(n_groups, device=losses.device).unsqueeze(1)
        grp_map = (grp == group_ids).float()
        grp_count = grp_map.sum(1)
        grp_denom = grp_count + (grp_count==0).float() # avoid nans
        # reshape (rather than view) also accepts non-contiguous inputs.
        grp_loss = (grp_map @ losses.reshape(-1))/grp_denom
        return grp_loss, grp_count

    def compute_robust_loss(grp_loss, grp_count, step_size = 1e-2):
        """Combine per-group losses into one loss that up-weights the worst groups.

        Args:
            grp_loss: 1-D tensor of per-group mean losses.
            grp_count: per-group sample counts; accepted for interface
                compatibility but not used in the computation.
            step_size: scale applied to the group losses before exponentiation.

        Returns:
            (robust_loss, adv_probs): the weighted sum of `grp_loss` and the
            weights themselves. The weights are exp(step_size * grp_loss)
            normalized to sum to 1, and are treated as constants (gradients flow
            only through `grp_loss`).

        NOTE(review): the weights are recomputed from the current batch on every
        call; nothing is carried over between steps -- confirm this is intended.
        """
        # detach(): the weights must not contribute to the gradient. (The earlier
        # uniform initialisation of adv_probs was overwritten immediately and has
        # been dropped.)
        adv_probs = torch.exp(step_size*grp_loss.detach())
        adv_probs = adv_probs/(adv_probs.sum())
        robust_loss = grp_loss @ adv_probs
        return robust_loss, adv_probs

    def loss_fn(pred, label, grp):
        """Group-robust training loss for one batch.

        Computes the unreduced loss with `criterion`, averages it within each
        group via `compute_grp_avg`, and combines the group averages with
        `compute_robust_loss`.

        Args:
            pred: model predictions for the batch.
            label: targets with the same shape as `pred`.
            grp: group index of every sample in the batch.

        Returns:
            The scalar robust loss to back-propagate.
        """
        per_sample_losses = criterion(pred, label)
        grp_loss, grp_count = compute_grp_avg(per_sample_losses, grp)
        # The per-group accuracy that used to be computed here was never used,
        # so that extra pass over the batch has been removed.
        actual_loss, _ = compute_robust_loss(grp_loss, grp_count)
        return actual_loss

    
    # Create the output directory up front so that a full training run is not
    # lost when the results are written at the very end.
    os.makedirs('results/cifar_watermark', exist_ok=True)

    for seed in range(10):
        print(f'==========================Seed: {seed}=============================')
        for model_name in ['resnet', 'vgg', 'alexnet']:
            print(f'========================Model: {model_name}=========================')
            for zeta_train in zeta_trains:
                print(f'========================Zeta: {zeta_train}===========================')
                # Re-seed for every (seed, model, zeta) configuration.
                torch.manual_seed(seed)
                np.random.seed(seed)

                testset_nonbiased = CIFAR_subset_watermark(testset, index_test, alpha = .5, beta = .5)
                trainset_biased = CIFAR_subset_watermark(trainset, index_train, alpha = zeta_train, beta = zeta_train)

                trainloader = DataLoader(trainset_biased, batch_size=batch_size, num_workers = 8, shuffle=True)
                testloader = DataLoader(testset_nonbiased, batch_size=batch_size, num_workers = 8, shuffle=False)

                # Sample positions of the four (target, spurious) groups, named indices_<target><spurious>.
                indices_11 = np.where((trainset_biased.targets == 1) * (trainset_biased.spurious == 1))[0]
                indices_10 = np.where((trainset_biased.targets == 1) * (trainset_biased.spurious == 0))[0]
                indices_01 = np.where((trainset_biased.targets == 0) * (trainset_biased.spurious == 1))[0]
                indices_00 = np.where((trainset_biased.targets == 0) * (trainset_biased.spurious == 0))[0]
                # One loader per group, in the order 11, 10, 01, 00 (group indices 0..3
                # of group_mapping), used to report per-group training accuracy.
                trainSubloader = [
                    DataLoader(torch.utils.data.Subset(trainset_biased, indices = group_indices),
                               batch_size=batch_size, num_workers = 8, shuffle=True)
                    for group_indices in (indices_11, indices_10, indices_01, indices_00)
                ]

                if model_name == 'resnet':
                    model = Classifier_resnet(latent_dim).to(device)
                elif model_name == 'vgg':
                    model = Classifier_vgg(latent_dim).to(device)
                elif model_name == 'alexnet':
                    model = Classifier_alexnet(latent_dim).to(device)

                optimizer = optim.SGD(model.parameters(), lr = 1e-3, momentum = 0.9, weight_decay = 5e-4)
                # Cut the learning rate by 10x when the epoch loss stops decreasing.
                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer,'min',factor=0.1,patience=5,threshold=0.0001,min_lr=0,eps=1e-08)
                # Per-epoch histories that are written to disk after training.
                Acc_train = []
                Acc_test = []
                Cls_loss = []
                Reg_loss = []
                for epoch in range(num_epochs):
                    # train() switches the model to train mode itself, and the accuracy
                    # helpers switch it to eval mode, so no explicit mode change is needed here.
                    cls_loss, reg_loss = train(model, trainloader)
                    Acc = group_accuracy(model, trainSubloader)
                    # The scheduler is driven by the summed (not averaged) epoch loss.
                    scheduler.step(cls_loss)
                    worst_group = Acc.index(min(Acc))
                    acc_test = total_accuracy(model, testloader)

                    # zeta_train is a float below 1, so it must not be formatted with %d
                    # (which truncated every value to 0 in the log).
                    print('S/M/Z [%d/%s/%.3f] Epoch [%d/%d], grp [%d], Loss [%.4f/%.4f], Gropu Acc [%.2f,%.2f,%.2f,%.2f] Test: [%.4f]' % (
                        seed, model_name, zeta_train, epoch+1, num_epochs, worst_group, cls_loss/len(trainloader), reg_loss/len(trainloader),
                        Acc[0], Acc[1], Acc[2], Acc[3], acc_test))

                    Acc_train.append(Acc)
                    Acc_test.append(acc_test)
                    Cls_loss.append(cls_loss/len(trainloader))
                    Reg_loss.append(reg_loss/len(trainloader))

                # NOTE(review): the file name patterns below are not consistent with each
                # other (missing '_' separators); they are left unchanged so that anything
                # reading these results keeps working.
                torch.save(model.state_dict(), f'results/cifar_watermark/DRO_{model_name}_{zeta_train}_{num_epochs}_seed{seed}')
                np.savetxt(f'results/cifar_watermark/trainAcc_zeta_{model_name}{zeta_train}_DRO_{num_epochs}_seed{seed}.txt', Acc_train)
                np.savetxt(f'results/cifar_watermark/testAcc_zeta_{model_name}_{zeta_train}_DRO_{num_epochs}_seed{seed}.txt', Acc_test)
                np.savetxt(f'results/cifar_watermark/clsLoss_zeta_{model_name}_{zeta_train}_DRO_{num_epochs}_seed{seed}.txt', Cls_loss)
                np.savetxt(f'results/cifar_watermark/regLoss_zeta{model_name}_{zeta_train}_DRO_{num_epochs}_seed{seed}.txt', Reg_loss)


S/M/Z [0/resnet/0] Epoch [1/200], grp [0], Loss [0.5121/0.1993], Gropu Acc [0.00,0.00,1.00,1.00] Test: [0.5000]
S/M/Z [0/resnet/0] Epoch [2/200], grp [0], Loss [0.4711/0.2737], Gropu Acc [0.00,0.00,1.00,1.00] Test: [0.5000]
S/M/Z [0/resnet/0] Epoch [3/200], grp [0], Loss [0.4494/0.1068], Gropu Acc [0.00,0.20,1.00,1.00] Test: [0.5020]
S/M/Z [0/resnet/0] Epoch [4/200], grp [0], Loss [0.4506/0.0885], Gropu Acc [0.04,0.60,1.00,0.99] Test: [0.5120]
S/M/Z [0/resnet/0] Epoch [5/200], grp [0], Loss [0.4280/0.1913], Gropu Acc [0.17,0.80,1.00,0.97] Test: [0.5640]
S/M/Z [0/resnet/0] Epoch [6/200], grp [0], Loss [0.4256/0.1008], Gropu Acc [0.34,1.00,1.00,0.93] Test: [0.6190]
S/M/Z [0/resnet/0] Epoch [7/200], grp [0], Loss [0.4092/0.1773], Gropu Acc [0.51,1.00,1.00,0.87] Test: [0.6850]
S/M/Z [0/resnet/0] Epoch [8/200], grp [0], Loss [0.3976/0.0000], Gropu Acc [0.62,1.00,1.00,0.83] Test: [0.7210]
S/M/Z [0/resnet/0] Epoch [9/200], grp [0], Loss [0.4001/0.1811], Gropu Acc [0.69,1.00,1.00,0.80] Test: [

S/M/Z [0/resnet/0] Epoch [72/200], grp [3], Loss [0.0067/0.1169], Gropu Acc [1.00,1.00,1.00,1.00] Test: [0.9155]
S/M/Z [0/resnet/0] Epoch [73/200], grp [3], Loss [0.0064/0.0885], Gropu Acc [1.00,1.00,1.00,1.00] Test: [0.9160]
S/M/Z [0/resnet/0] Epoch [74/200], grp [3], Loss [0.0060/0.1242], Gropu Acc [1.00,1.00,1.00,1.00] Test: [0.9160]
S/M/Z [0/resnet/0] Epoch [75/200], grp [3], Loss [0.0057/0.0897], Gropu Acc [1.00,1.00,1.00,1.00] Test: [0.9165]
S/M/Z [0/resnet/0] Epoch [76/200], grp [3], Loss [0.0055/0.0436], Gropu Acc [1.00,1.00,1.00,1.00] Test: [0.9155]
S/M/Z [0/resnet/0] Epoch [77/200], grp [3], Loss [0.0054/0.0000], Gropu Acc [1.00,1.00,1.00,1.00] Test: [0.9150]
S/M/Z [0/resnet/0] Epoch [78/200], grp [3], Loss [0.0052/0.0830], Gropu Acc [1.00,1.00,1.00,1.00] Test: [0.9160]
S/M/Z [0/resnet/0] Epoch [79/200], grp [3], Loss [0.0049/0.0782], Gropu Acc [1.00,1.00,1.00,1.00] Test: [0.9160]
S/M/Z [0/resnet/0] Epoch [80/200], grp [0], Loss [0.0047/0.0457], Gropu Acc [1.00,1.00,1.00,1.00