In [1]:
import sys
sys.path = ["./", "../examples/", "../", ] + sys.path
#sys.path.append("../")
#sys.path.append("../examples/")
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from torchmetrics import CalibrationError
import joblib

from models.initializer import initialize_torchvision_model, initialize_model
from transforms import initialize_transform
from utils import get_config
import wilds
from wilds.common.grouper import CombinatorialGrouper

In [2]:
def get_dataset_loader(config):
    full_dataset = wilds.get_dataset(
        dataset=config.dataset,
        version=config.version,
        root_dir=config.root_dir,
        download=True,
        split_scheme=config.split_scheme,
        **config.dataset_kwargs)
    eval_transform = initialize_transform(
        transform_name=config.transform,
        config=config,
        dataset=full_dataset,
        is_training=False)
    train_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=config.groupby_fields)
    
    trn_dset = full_dataset.get_subset(
        "train",
        train_grouper=train_grouper,
        frac=config.frac,
        transform=eval_transform,
        subsample_to_minority=config.subsample)
    trn_loader = DataLoader(
        trn_dset,
        shuffle=False, # Do not shuffle eval datasets
        sampler=None,
        collate_fn=trn_dset.collate,
        batch_size=config.batch_size,
        **config.loader_kwargs)
    tst_dset = full_dataset.get_subset(
        "test",
        train_grouper=train_grouper,
        frac=config.frac,
        transform=eval_transform,
        subsample_to_minority=config.subsample)
    tst_loader = DataLoader(
        tst_dset,
        shuffle=False, # Do not shuffle eval datasets
        sampler=None,
        collate_fn=tst_dset.collate,
        batch_size=config.batch_size,
        **config.loader_kwargs)
    return trn_loader, tst_loader

In [7]:
dataset = "celebA"
config = get_config(dataset, "ERM", "../data")
config.batch_size = 32
trn_loader, tst_loader = get_dataset_loader(config)

params = [{
    'name': "ERM",
    "arch": "resnet50",
    'model_path': "../logs/celebA/erm-resnet50/celebA_seed:0_epoch:last_model.pth",
    'pred_dir': "../logs/celebA/erm-resnet50/",
}]
params += [{
    'name': "ERM DPSGD",
    "arch": "dp_resnet50",
    'model_path': f"../logs/celebA/erm-dp_resnet50-lr1e-3-dpsgd_1e-5_{gamma}_1.0_0.0001/celebA_seed:0_epoch:last_model.pth",
    'pred_dir': f"../logs/celebA/erm-dp_resnet50-lr1e-3-dpsgd_1e-5_{gamma}_1.0_0.0001/",
} for gamma in [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0]]

In [8]:
device = "cuda"

for param in params:
    name, arch, model_path, pred_dir = param['name'], param['arch'], param['model_path'], param['pred_dir']

    if os.path.exists(os.path.join(pred_dir, "preds.pkl")):
        continue
        
    d_out = 2
    model = initialize_torchvision_model(arch, d_out)
    res = torch.load(model_path)['algorithm']
    state_dict = {}
    for k, v in res.items():
        if "dp" in arch:
            state_dict[k.replace("model._module.", "")] = v
        else:
            state_dict[k.replace("model.", "")] = v
    model.load_state_dict(state_dict)
    _ = model.to(device)
    
    ret = {"trn": [], "tst": [], "trny": [], "tsty": []}
    for x, y, _ in tqdm(trn_loader):
        ret["trn"].append(model(x.to(device)).detach().cpu())
        ret["trny"].append(y.detach().cpu())
    ret["trn"] = torch.cat(ret["trn"], dim=0)
    for x, y, _ in tqdm(tst_loader):
        ret["tst"].append(model(x.to(device)).detach().cpu())
        ret["tsty"].append(y.detach().cpu())
    ret["tst"] = torch.cat(ret["tst"], dim=0)
    
    joblib.dump(ret, os.path.join(pred_dir, "preds.pkl"))

  0%|          | 0/5087 [00:00<?, ?it/s]

  0%|          | 0/624 [00:00<?, ?it/s]

  0%|          | 0/5087 [00:00<?, ?it/s]

  0%|          | 0/624 [00:00<?, ?it/s]

In [5]:
results = {}

In [10]:
dataset = "celebA"
config = get_config(dataset, "ERM", "../data")
loader = get_dataset_loader(config)

params = [{
    'name': "ERM",
    "arch": "resnet50",
    'model_path': "../logs/celebA/erm/celebA_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO",
    "arch": "resnet50",
    'model_path': "../logs/celebA/groupDRO/celebA_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO wd",
    "arch": "resnet50",
    'model_path': "../logs/celebA/groupDRO_wd0.1/celebA_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM IW",
    "arch": "resnet18",
    'model_path': "../logs/celebA/erm_reweight/celebA_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM DPSGD",
    "arch": "dp_resnet18",
    'model_path': "../logs/celebA/erm-dp_resnet18-dpsgd_1e-5_1.0_0.1_0.0001/celebA_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM IW DPSGD",
    "arch": "dp_resnet18",
    'model_path': "../logs/celebA/iwerm-dp_resnet18-dpsgd_1e-5_1.0_1.0_0.0001/celebA_seed:0_epoch:last_model.pth",
},
]

In [11]:
device = "cuda"

for param in params:
    name, arch, model_path = param['name'], param['arch'], param['model_path']

    d_out = 2
    model = initialize_torchvision_model(arch, d_out)
    res = torch.load(model_path)['algorithm']
    state_dict = {}
    for k, v in res.items():
        if "dp" in arch:
            state_dict[k.replace("model._module.", "")] = v
        else:
            state_dict[k.replace("model.", "")] = v
    model.load_state_dict(state_dict)
    _ = model.to(device)
    
    proba, truths = [], []
    for x, y, _ in tqdm(loader):
        proba.append(torch.nn.Softmax(dim=1)(model(x.to(device))).detach().cpu())
        truths.append(y)
    proba = torch.cat(proba, dim=0)
    truths = torch.cat(truths)
    
    error = CalibrationError()
    error(proba, truths)
    results[(dataset, name)] = error.compute().item()

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

In [15]:
df = pd.DataFrame.from_dict(results, orient="index")
df

Unnamed: 0,0
"(celebA, ERM)",0.041986
"(celebA, DRO wd)",0.366747
"(celebA, ERM IW)",0.10105
"(celebA, ERM DPSGD)",0.05478
"(celebA, ERM IW DPSGD)",0.055586


In [12]:
results

{('celebA', 'ERM'): 0.04198582097887993,
 ('celebA', 'DRO wd'): 0.36674678325653076,
 ('celebA', 'ERM IW'): 0.10104991495609283,
 ('celebA', 'ERM DPSGD'): 0.05478046089410782,
 ('celebA', 'ERM IW DPSGD'): 0.0555860698223114}

In [3]:
dataset = "utkface"
config = get_config(dataset, "ERM", "../data")
config.download = True
loader = get_dataset_loader(config)

params = [{
    'name': "ERM",
    "arch": "resnet50",
    'model_path': "../logs/utkface/erm-resnet50/UTKFace_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO",
    "arch": "resnet50",
    'model_path': "../logs/utkface/groupDRO-resnet50/UTKFace_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO wd",
    "arch": "resnet50",
    'model_path': "../logs/utkface/groupDRO-resnet50_wd0.1/UTKFace_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM IW",
    "arch": "resnet50",
    'model_path': "../logs/utkface/erm_reweight-resnet50/UTKFace_seed:0_epoch:last_model.pth",
#}, {
#    'name': "ERM IW DPSGD",
#    "arch": "dp_bert-base-uncased",
#    'model_path': "../logs/celebA/iwerm-dp_resnet18-dpsgd_1e-5_1.0_1.0_0.0001/celebA_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM DPSGD",
    "arch": "dp_resnet50",
    'model_path': "../logs/utkface/erm-dp_resnet50-dpsgd_1e-5_0.5_1.0_0.0005/UTKFace_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM IW DPSGD",
    "arch": "dp_resnet50",
    'model_path': "../logs/utkface/weightederm-dp_resnet50-dpsgd_1e-5_0.01_1.0_0.001/UTKFace_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO DPSGD",
    "arch": "dp_resnet50",
    'model_path': "../logs/utkface/groupdro-dp_resnet50-dpsgd_1e-5_0.01_1.0_0.001/UTKFace_seed:0_epoch:last_model.pth",
},
]

Downloading dataset to ../data/UTKFace_v1.0...
You can also download the dataset manually at https://wilds.stanford.edu/downloads.
Downloading  to ../data/UTKFace_v1.0/archive.tar.gz


0Byte [00:00, ?Byte/s]


../data/UTKFace_v1.0/archive.tar.gz may be corrupted. Please try deleting it and rerunning this command.

Exception:  unknown url type: ''
problem with:  ../data/UTKFace_v1.0/39_1_20170116174525125.jpg.chip.jpg
problem with:  ../data/UTKFace_v1.0/61_1_20170109142408075.jpg.chip.jpg
problem with:  ../data/UTKFace_v1.0/61_1_20170109150557335.jpg.chip.jpg


In [6]:
device = "cuda"

for param in params:
    name, arch, model_path = param['name'], param['arch'], param['model_path']
    config.model = arch
    
    
    d_out = 2
    model = initialize_model(config, d_out)
    res = torch.load(model_path)['algorithm']
    state_dict = {}
    for k, v in res.items():
        if "dp" in arch:
            state_dict[k.replace("model._module.", "")] = v
        else:
            state_dict[k.replace("model.", "")] = v
    model.load_state_dict(state_dict)
    _ = model.to(device)
    
    proba, truths = [], []
    for x, y, _ in tqdm(loader):
        proba.append(torch.nn.Softmax(dim=1)(model(x.to(device))).detach().cpu())
        truths.append(y)
    proba = torch.cat(proba, dim=0)
    truths = torch.cat(truths)
    
    error = CalibrationError()
    error(proba, truths)
    results[(dataset, name)] = error.compute().item()

  0%|          | 0/74 [00:00<?, ?it/s]

  0%|          | 0/74 [00:00<?, ?it/s]

  0%|          | 0/74 [00:00<?, ?it/s]

  0%|          | 0/74 [00:00<?, ?it/s]

  0%|          | 0/74 [00:00<?, ?it/s]

  0%|          | 0/74 [00:00<?, ?it/s]

  0%|          | 0/74 [00:00<?, ?it/s]

In [8]:
dataset = "civilcomments"
config = get_config(dataset, "ERM", "../data")

params = [{
    'name': "ERM",
    "arch": "head_bert-base-uncased",
    'model_path': "../logs/civilcomments/erm-head_bert-base-uncased/civilcomments_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO",
    "arch": "head_bert-base-uncased",
    'model_path': "../logs/civilcomments/groupDRO-head_bert-base-uncased/civilcomments_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO wd",
    "arch": "head_bert-base-uncased",
    'model_path': "../logs/civilcomments/groupDRO-head_bert-base-uncased_wd1.0/civilcomments_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM IW",
    "arch": "head_bert-base-uncased",
    'model_path': "../logs/civilcomments/erm_reweight-head_bert-base-uncased/civilcomments_seed:0_epoch:last_model.pth",
#}, {
#    'name': "ERM IW DPSGD",
#    "arch": "dp_bert-base-uncased",
#    'model_path': "../logs/celebA/iwerm-dp_resnet18-dpsgd_1e-5_1.0_1.0_0.0001/celebA_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM IW DPSGD",
    "arch": "dp_bert-base-uncased",
    'model_path': "../logs/civilcomments/weightederm-dp_bert-base-uncased-dpsgd_1e-5_0.5_1.0_0.0001/civilcomments_seed:0_epoch:last_model.pth",
}, {
    'name': "IWERM DPSGD",
    "arch": "dp_bert-base-uncased",
    'model_path': "../logs/civilcomments/iwerm-dp_bert-base-uncased-lr1e-5_dpAdamW_1e-5_0.001_1.0_0.0002/civilcomments_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO DPSGD",
    "arch": "dp_bert-base-uncased",
    'model_path': "../logs/civilcomments/groupdro-dp_bert-base-uncased-lr1e-5_dpAdamW_1e-5_0.001_1.0_0.0002/civilcomments_seed:0_epoch:last_model.pth",
},
]

In [9]:
device = "cuda"

for param in params:
    name, arch, model_path = param['name'], param['arch'], param['model_path']
    config.model = arch
    loader = get_dataset_loader(config)
    
    d_out = 2
    model = initialize_model(config, d_out)
    res = torch.load(model_path)['algorithm']
    state_dict = {}
    for k, v in res.items():
        if "dp" in arch:
            state_dict[k.replace("model._module.", "")] = v
        else:
            state_dict[k.replace("model.", "")] = v
    model.load_state_dict(state_dict)
    _ = model.to(device)
    
    proba, truths = [], []
    counts = 0
    for x, y, _ in tqdm(loader):
        proba.append(torch.nn.Softmax(dim=1)(model(x.to(device))).detach().cpu())
        truths.append(y)
        counts += 1
        if counts == 1000:
            break
    proba = torch.cat(proba, dim=0)
    truths = torch.cat(truths)
    
    error = CalibrationError()
    error(proba, truths)
    results[(dataset, name)] = error.compute().item()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.b

  0%|          | 0/8362 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.b

  0%|          | 0/8362 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.b

  0%|          | 0/8362 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.b

  0%|          | 0/8362 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.b

  0%|          | 0/8362 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.b

  0%|          | 0/8362 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.b

  0%|          | 0/8362 [00:00<?, ?it/s]

In [12]:
df = pd.DataFrame.from_dict(results, orient="index")
df.index = pd.MultiIndex.from_tuples(df.index)
df

Unnamed: 0,Unnamed: 1,0
utkface,ERM,0.020405
utkface,DRO,0.060203
utkface,DRO wd,0.068113
utkface,ERM IW,0.01625
utkface,ERM DPSGD,0.0265
utkface,ERM IW DPSGD,0.060824
utkface,DRO DPSGD,0.068309
civilcomments,ERM,0.037293
civilcomments,DRO,0.069121
civilcomments,DRO wd,0.062794


In [7]:
df = pd.DataFrame.from_dict(results, orient="index")
df.index = pd.MultiIndex.from_tuples(df.index)
df

Unnamed: 0,Unnamed: 1,0
utkface,ERM,0.020405
utkface,DRO,0.060203
utkface,DRO wd,0.068113
utkface,ERM IW,0.01625
utkface,ERM DPSGD,0.0265
utkface,ERM IW DPSGD,0.060824
utkface,DRO DPSGD,0.068309
