In [19]:
import sys
sys.path.append("../")
sys.path.append("../examples/")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from torchmetrics import CalibrationError

from models.initializer import initialize_torchvision_model, initialize_model
from transforms import initialize_transform
from utils import get_config
import wilds
from wilds.common.grouper import CombinatorialGrouper

In [35]:
def get_dataset_loader(config):
    full_dataset = wilds.get_dataset(
        dataset=config.dataset,
        version=config.version,
        root_dir=config.root_dir,
        download=False,
        split_scheme=config.split_scheme,
        **config.dataset_kwargs)
    eval_transform = initialize_transform(
        transform_name=config.transform,
        config=config,
        dataset=full_dataset,
        is_training=False)
    train_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=config.groupby_fields)
    tst_dset = full_dataset.get_subset(
        "test",
        train_grouper=train_grouper,
        frac=config.frac,
        transform=eval_transform,
        subsample_to_minority=config.subsample)
    loader = DataLoader(
        tst_dset,
        shuffle=False, # Do not shuffle eval datasets
        sampler=None,
        collate_fn=tst_dset.collate,
        batch_size=config.batch_size,
        **config.loader_kwargs)
    return loader

In [9]:
dataset = "celebA"
config = get_config(dataset, "ERM", "../data")
loader = get_dataset_loader(config)

params = [{
    'name': "ERM",
    "arch": "resnet18",
    'model_path': "../logs/celebA/erm/celebA_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO wd",
    "arch": "resnet18",
    'model_path': "../logs/celebA/groupDRO_wd1.0/celebA_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM IW",
    "arch": "resnet18",
    'model_path': "../logs/celebA/erm_reweight/celebA_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM DPSGD",
    "arch": "dp_resnet18",
    'model_path': "../logs/celebA/erm-dp_resnet18-dpsgd_1e-5_1.0_0.1_0.0001/celebA_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM IW DPSGD",
    "arch": "dp_resnet18",
    'model_path': "../logs/celebA/iwerm-dp_resnet18-dpsgd_1e-5_1.0_1.0_0.0001/celebA_seed:0_epoch:last_model.pth",
},
]

In [10]:
device = "cuda"

results = {}
for param in params:
    name, arch, model_path = param['name'], param['arch'], param['model_path']

    d_out = 2
    model = initialize_torchvision_model(arch, d_out)
    res = torch.load(model_path)['algorithm']
    state_dict = {}
    for k, v in res.items():
        if "dp" in arch:
            state_dict[k.replace("model._module.", "")] = v
        else:
            state_dict[k.replace("model.", "")] = v
    model.load_state_dict(state_dict)
    _ = model.to(device)
    
    proba, truths = [], []
    for x, y, _ in tqdm(loader):
        proba.append(torch.nn.Softmax(dim=1)(model(x.to(device))).detach().cpu())
        truths.append(y)
    proba = torch.cat(proba, dim=0)
    truths = torch.cat(truths)
    
    error = CalibrationError()
    error(proba, truths)
    results[(dataset, name)] = error.compute().item()

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

In [15]:
df = pd.DataFrame.from_dict(results, orient="index")
df

Unnamed: 0,0
"(celebA, ERM)",0.041986
"(celebA, DRO wd)",0.366747
"(celebA, ERM IW)",0.10105
"(celebA, ERM DPSGD)",0.05478
"(celebA, ERM IW DPSGD)",0.055586


In [12]:
results

{('celebA', 'ERM'): 0.04198582097887993,
 ('celebA', 'DRO wd'): 0.36674678325653076,
 ('celebA', 'ERM IW'): 0.10104991495609283,
 ('celebA', 'ERM DPSGD'): 0.05478046089410782,
 ('celebA', 'ERM IW DPSGD'): 0.0555860698223114}

In [58]:
dataset = "utkface"
config = get_config(dataset, "ERM", "../data")
config.download = True

params = [{
    'name': "ERM",
    "arch": "resnet50",
    'model_path': "../logs/utkface/erm-resnet50/UTKFace_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO",
    "arch": "resnet50",
    'model_path': "../logs/utkface/",
}, {
    'name': "DRO wd",
    "arch": "resnet50",
    'model_path': "../logs/civilcomments/groupDRO-head_bert-base-uncased_wd1.0/civilcomments_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM IW",
    "arch": "resnet50",
    'model_path': "../logs/civilcomments/erm_reweight-head_bert-base-uncased/civilcomments_seed:0_epoch:last_model.pth",
#}, {
#    'name': "ERM IW DPSGD",
#    "arch": "dp_bert-base-uncased",
#    'model_path': "../logs/celebA/iwerm-dp_resnet18-dpsgd_1e-5_1.0_1.0_0.0001/celebA_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM IW DPSGD",
    "arch": "dp_bert-base-uncased",
    'model_path': "../logs/civilcomments/weightederm-dp_bert-base-uncased-dpsgd_1e-5_0.5_1.0_0.0001/civilcomments_seed:0_epoch:last_model.pth",
}, {
    'name': "IWERM DPSGD",
    "arch": "dp_bert-base-uncased",
    'model_path': "../logs/civilcomments/iwerm-dp_bert-base-uncased-lr1e-5_dpAdamW_1e-5_0.001_1.0_0.0002/civilcomments_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO DPSGD",
    "arch": "dp_bert-base-uncased",
    'model_path': "../logs/civilcomments/groupdro-dp_bert-base-uncased-lr1e-5_dpAdamW_1e-5_0.001_1.0_0.0002/civilcomments_seed:0_epoch:last_model.pth",
},
]

In [None]:
device = "cuda"

for param in params:
    name, arch, model_path = param['name'], param['arch'], param['model_path']
    config.model = arch
    loader = get_dataset_loader(config)
    
    d_out = 2
    model = initialize_model(config, d_out)
    res = torch.load(model_path)['algorithm']
    state_dict = {}
    for k, v in res.items():
        if "dp" in arch:
            state_dict[k.replace("model._module.", "")] = v
        else:
            state_dict[k.replace("model.", "")] = v
    model.load_state_dict(state_dict)
    _ = model.to(device)
    
    proba, truths = [], []
    counts = 0
    for x, y, _ in tqdm(loader):
        proba.append(torch.nn.Softmax(dim=1)(model(x.to(device))).detach().cpu())
        truths.append(y)
        counts += 1
        if counts == 100:
            break
    proba = torch.cat(proba, dim=0)
    truths = torch.cat(truths)
    
    error = CalibrationError()
    error(proba, truths)
    results[(dataset, name)] = error.compute().item()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.w

  0%|          | 0/8362 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.w

  0%|          | 0/8362 [00:00<?, ?it/s]

In [58]:
dataset = "civilcomments"
config = get_config(dataset, "ERM", "../data")

params = [{
    'name': "ERM",
    "arch": "head_bert-base-uncased",
    'model_path': "../logs/civilcomments/erm-head_bert-base-uncased/civilcomments_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO",
    "arch": "head_bert-base-uncased",
    'model_path': "../logs/civilcomments/groupDRO-head_bert-base-uncased/civilcomments_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO wd",
    "arch": "head_bert-base-uncased",
    'model_path': "../logs/civilcomments/groupDRO-head_bert-base-uncased_wd1.0/civilcomments_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM IW",
    "arch": "head_bert-base-uncased",
    'model_path': "../logs/civilcomments/erm_reweight-head_bert-base-uncased/civilcomments_seed:0_epoch:last_model.pth",
#}, {
#    'name': "ERM IW DPSGD",
#    "arch": "dp_bert-base-uncased",
#    'model_path': "../logs/celebA/iwerm-dp_resnet18-dpsgd_1e-5_1.0_1.0_0.0001/celebA_seed:0_epoch:last_model.pth",
}, {
    'name': "ERM IW DPSGD",
    "arch": "dp_bert-base-uncased",
    'model_path': "../logs/civilcomments/weightederm-dp_bert-base-uncased-dpsgd_1e-5_0.5_1.0_0.0001/civilcomments_seed:0_epoch:last_model.pth",
}, {
    'name': "IWERM DPSGD",
    "arch": "dp_bert-base-uncased",
    'model_path': "../logs/civilcomments/iwerm-dp_bert-base-uncased-lr1e-5_dpAdamW_1e-5_0.001_1.0_0.0002/civilcomments_seed:0_epoch:last_model.pth",
}, {
    'name': "DRO DPSGD",
    "arch": "dp_bert-base-uncased",
    'model_path': "../logs/civilcomments/groupdro-dp_bert-base-uncased-lr1e-5_dpAdamW_1e-5_0.001_1.0_0.0002/civilcomments_seed:0_epoch:last_model.pth",
},
]

In [59]:
device = "cuda"

for param in params:
    name, arch, model_path = param['name'], param['arch'], param['model_path']
    config.model = arch
    loader = get_dataset_loader(config)
    
    d_out = 2
    model = initialize_model(config, d_out)
    res = torch.load(model_path)['algorithm']
    state_dict = {}
    for k, v in res.items():
        if "dp" in arch:
            state_dict[k.replace("model._module.", "")] = v
        else:
            state_dict[k.replace("model.", "")] = v
    model.load_state_dict(state_dict)
    _ = model.to(device)
    
    proba, truths = [], []
    counts = 0
    for x, y, _ in tqdm(loader):
        proba.append(torch.nn.Softmax(dim=1)(model(x.to(device))).detach().cpu())
        truths.append(y)
        counts += 1
        if counts == 100:
            break
    proba = torch.cat(proba, dim=0)
    truths = torch.cat(truths)
    
    error = CalibrationError()
    error(proba, truths)
    results[(dataset, name)] = error.compute().item()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.w

  0%|          | 0/8362 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.w

  0%|          | 0/8362 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.w

  0%|          | 0/8362 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.w

  0%|          | 0/8362 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.w

  0%|          | 0/8362 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.w

  0%|          | 0/8362 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertClassifier: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertClassifier were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.w

  0%|          | 0/8362 [00:00<?, ?it/s]

In [60]:
df = pd.DataFrame.from_dict(results, orient="index")
df.index = pd.MultiIndex.from_tuples(df.index)
df

Unnamed: 0,Unnamed: 1,0
celebA,ERM,0.041986
celebA,DRO wd,0.366747
celebA,ERM IW,0.10105
celebA,ERM DPSGD,0.05478
celebA,ERM IW DPSGD,0.055586
civilcomments,ERM,0.107468
civilcomments,ERM IW,0.1309
civilcomments,DRO,0.150426
civilcomments,DRO wd,0.137926
civilcomments,ERM IW DPSGD,0.131587
