8 changes: 5 additions & 3 deletions README.md
@@ -29,7 +29,7 @@ pip install wilds
If you have already installed it, please check that you have the latest version:
```bash
python -c "import wilds; print(wilds.__version__)"
# This should print "1.2.1". If it doesn't, update by running:
# This should print "1.2.2". If it doesn't, update by running:
pip install -U wilds
```

@@ -50,7 +50,10 @@ pip install -e .
- torch>=1.7.0
- torch-scatter>=2.0.5
- torch-geometric>=1.6.1
- torchvision>=0.8.2
- tqdm>=4.53.0
- scikit-learn>=0.20.0
- scipy>=1.5.4

Running `pip install wilds` or `pip install -e .` will automatically check for and install all of these requirements
except for the `torch-scatter` and `torch-geometric` packages, which require a [quick manual install](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html#installation-via-binaries).
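For reference, a hedged sketch of that manual step. The exact command and wheel URL depend on your PyTorch and CUDA versions, so treat the snippet below (written for torch 1.7.0 + CUDA 10.1) as an example only and follow the linked instructions for your setup:

```bash
# Example only: check the linked installation guide for the wheel URL
# matching your own torch/CUDA versions.
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
pip install torch-geometric
```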
@@ -63,9 +66,8 @@ These scripts are not part of the installed WILDS package. To use them, you shou
git clone git@github.com:p-lambda/wilds.git
```

To run these scripts, you will need to install these additional dependencies:
To run these scripts, you will also need to install this additional dependency:

- torchvision>=0.8.2
- transformers>=3.5.0

All baseline experiments in the paper were run on Python 3.8.5 and CUDA 10.1.
36 changes: 25 additions & 11 deletions examples/configs/utils.py
@@ -1,3 +1,4 @@
import copy
from configs.algorithm import algorithm_defaults
from configs.model import model_defaults
from configs.scheduler import scheduler_defaults
@@ -7,41 +8,44 @@
def populate_defaults(config):
"""Populates hyperparameters with defaults implied by choices
of other hyperparameters."""

orig_config = copy.deepcopy(config)
assert config.dataset is not None, 'dataset must be specified'
assert config.algorithm is not None, 'algorithm must be specified'

# implied defaults from choice of dataset
config = populate_config(
config,
config,
dataset_defaults[config.dataset]
)

# implied defaults from choice of split
if config.dataset in split_defaults and config.split_scheme in split_defaults[config.dataset]:
config = populate_config(
config,
config,
split_defaults[config.dataset][config.split_scheme]
)

# implied defaults from choice of algorithm
config = populate_config(
config,
config,
algorithm_defaults[config.algorithm]
)

# implied defaults from choice of loader
config = populate_config(
config,
config,
loader_defaults
)
# implied defaults from choice of model
if config.model: config = populate_config(
config,
config,
model_defaults[config.model],
)

# implied defaults from choice of scheduler
if config.scheduler: config = populate_config(
config,
config,
scheduler_defaults[config.scheduler]
)

@@ -52,12 +56,22 @@ def populate_defaults(config):

# basic checks
required_fields = [
'split_scheme', 'train_loader', 'uniform_over_groups', 'batch_size', 'eval_loader', 'model', 'loss_function',
'split_scheme', 'train_loader', 'uniform_over_groups', 'batch_size', 'eval_loader', 'model', 'loss_function',
'val_metric', 'val_metric_decreasing', 'n_epochs', 'optimizer', 'lr', 'weight_decay',
]
]
for field in required_fields:
assert getattr(config, field) is not None, f"Must manually specify {field} for this setup."

# data loader validations
# we only raise this error if the train_loader is standard, and
# n_groups_per_batch or distinct_groups are
# specified by the user (instead of populated as a default)
if config.train_loader == 'standard':
if orig_config.n_groups_per_batch is not None:
raise ValueError("n_groups_per_batch cannot be specified if the data loader is 'standard'. Consider using a 'group' data loader instead.")
if orig_config.distinct_groups is not None:
raise ValueError("distinct_groups cannot be specified if the data loader is 'standard'. Consider using a 'group' data loader instead.")

return config

def populate_config(config, template: dict, force_compatibility=False):
@@ -78,7 +92,7 @@ def populate_config(config, template: dict, force_compatibility=False):
d_config[key] = val
elif d_config[key] != val and force_compatibility:
raise ValueError(f"Argument {key} must be set to {val}")

else: # config[key] expected to be a kwarg dict
for kwargs_key, kwargs_val in val.items():
if kwargs_key not in d_config[key] or d_config[key][kwargs_key] is None:
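To illustrate the new data loader check above, here is a small standalone sketch (not the actual WILDS code path, just the distilled rule): with the `standard` train loader, the group-loader options must stay unset.

```python
# Standalone illustration of the validation added to populate_defaults():
# group-specific loader options are only meaningful for the 'group' loader.
def check_loader_options(train_loader, n_groups_per_batch, distinct_groups):
    if train_loader == 'standard':
        if n_groups_per_batch is not None:
            raise ValueError(
                "n_groups_per_batch cannot be specified if the data loader is "
                "'standard'. Consider using a 'group' data loader instead.")
        if distinct_groups is not None:
            raise ValueError(
                "distinct_groups cannot be specified if the data loader is "
                "'standard'. Consider using a 'group' data loader instead.")

check_loader_options('group', n_groups_per_batch=2, distinct_groups=True)         # OK
check_loader_options('standard', n_groups_per_batch=None, distinct_groups=None)   # OK
check_loader_options('standard', n_groups_per_batch=2, distinct_groups=None)      # raises ValueError
```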
9 changes: 4 additions & 5 deletions examples/models/detection/fasterrcnn.py
@@ -27,7 +27,6 @@
from torchvision.models.utils import load_state_dict_from_url
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection import _utils as det_utils
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN
from torchvision.models.detection.faster_rcnn import TwoMLPHead
@@ -127,11 +126,11 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets)
sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]
sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

box_loss.append(det_utils.smooth_l1_loss(
box_loss.append(F.smooth_l1_loss(
pred_bbox_deltas_[sampled_pos_inds],
regression_targets_[sampled_pos_inds],
beta=1 / 9,
size_average=False,
reduction='sum',
) / (sampled_inds.numel()))

objectness_loss.append(F.binary_cross_entropy_with_logits(
@@ -226,11 +225,11 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):

box_regression_ = box_regression_.reshape(N, -1, 4)

box_loss_ = det_utils.smooth_l1_loss(
box_loss_ = F.smooth_l1_loss(
box_regression_[sampled_pos_inds_subset, labels_pos],
regression_targets_[sampled_pos_inds_subset],
beta=1 / 9,
size_average=False,
reduction='sum',
)
box_loss.append(box_loss_ / labels_.numel())

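As a quick sanity check on the loss substitution above: `torch.nn.functional.smooth_l1_loss` accepts a `beta` argument from torch 1.7.0 onward, and `reduction='sum'` plays the role of the removed `size_average=False`. A minimal sketch, with random tensors standing in for the real box regression deltas and targets:

```python
# Sketch: random tensors stand in for pred_bbox_deltas / regression_targets.
import torch
import torch.nn.functional as F

pred_deltas = torch.randn(8, 4)
targets = torch.randn(8, 4)

# Summed smooth-L1 (Huber) loss with the same beta used in the diff above.
box_loss = F.smooth_l1_loss(pred_deltas, targets, beta=1 / 9, reduction='sum')
print(box_loss)
```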
5 changes: 3 additions & 2 deletions examples/run_expt.py
@@ -155,8 +155,9 @@ def main():
split_scheme=config.split_scheme,
**config.dataset_kwargs)

# To implement data augmentation (i.e., have different transforms
# at training time vs. test time), modify these two lines:
# To modify data augmentation, modify the following code block.
# If you want to use transforms that modify both `x` and `y`,
# set `do_transform_y` to True when initializing the `WILDSSubset` below.
train_transform = initialize_transform(
transform_name=config.transform,
config=config,
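For context, a sketch of the pattern the comment above refers to, using the `initialize_transform` signature from `examples/transforms.py`; the `config` and `full_dataset` names are placeholders for the objects built earlier in `main()`:

```python
# Sketch: separate train-time and eval-time transforms.
train_transform = initialize_transform(
    transform_name=config.transform,
    config=config,
    dataset=full_dataset,
    is_training=True)    # training-time augmentation, where the transform supports it

eval_transform = initialize_transform(
    transform_name=config.transform,
    config=config,
    dataset=full_dataset,
    is_training=False)   # deterministic preprocessing for evaluation
```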
20 changes: 8 additions & 12 deletions examples/transforms.py
@@ -7,8 +7,9 @@

def initialize_transform(transform_name, config, dataset, is_training):
"""
Transforms should take in a single (x, y)
and return (transformed_x, transformed_y).
By default, transforms should take in `x` and return `transformed_x`.
For transforms that take in `(x, y)` and return `(transformed_x, transformed_y)`,
set `do_transform_y` to True when initializing the WILDSSubset.
"""
if transform_name is None:
return None
@@ -25,11 +26,6 @@ def initialize_transform(transform_name, config, dataset, is_training):
else:
raise ValueError(f"{transform_name} not recognized")

def transform_input_only(input_transform):
def transform(x, y):
return input_transform(x), y
return transform

def initialize_bert_transform(config):
assert 'bert' in config.model
assert config.max_token_length is not None
@@ -55,7 +51,7 @@ def transform(text):
dim=2)
x = torch.squeeze(x, dim=0) # First shape dim is always 1
return x
return transform_input_only(transform)
return transform

def getBertTokenizer(model):
if model == 'bert-base-uncased':
@@ -79,7 +75,7 @@ def initialize_image_base_transform(config, dataset):
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]
transform = transforms.Compose(transform_steps)
return transform_input_only(transform)
return transform

def initialize_image_resize_and_center_crop_transform(config, dataset):
"""
@@ -98,7 +94,7 @@ def initialize_image_resize_and_center_crop_transform(config, dataset):
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
return transform_input_only(transform)
return transform

def initialize_poverty_transform(is_training):
if is_training:
@@ -115,7 +111,7 @@ def transform_rgb(img):
img[:3] = rgb_transform(img[:3][[2,1,0]])[[2,1,0]]
return img
transform = transforms.Lambda(lambda x: transform_rgb(x))
return transform_input_only(transform)
return transform
else:
return None

@@ -148,4 +144,4 @@ def random_rotation(x: torch.Tensor) -> torch.Tensor:
t_standardize,
]
transform = transforms.Compose(transforms_ls)
return transform_input_only(transform)
return transform
10 changes: 6 additions & 4 deletions setup.py
@@ -22,14 +22,16 @@
long_description_content_type="text/markdown",
install_requires = [
'numpy>=1.19.1',
'ogb>=1.2.6',
'outdated>=0.2.0',
'pandas>=1.1.0',
'scikit-learn>=0.20.0',
'pillow>=7.2.0',
'pytz>=2020.4',
'torch>=1.7.0',
'ogb>=1.2.6',
'torchvision>=0.8.2',
'tqdm>=4.53.0',
'outdated>=0.2.0',
'pytz>=2020.4',
'scikit-learn>=0.20.0',
'scipy>=1.5.4'
],
license='MIT',
packages=setuptools.find_packages(exclude=['dataset_preprocessing', 'examples', 'examples.models', 'examples.models.bert']),
15 changes: 10 additions & 5 deletions wilds/common/metrics/all_metrics.py
@@ -1,10 +1,10 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops.boxes import box_iou
from torchvision.models.detection._utils import Matcher
from torchvision.ops import nms, box_convert
import numpy as np
import torch.nn.functional as F
from wilds.common.metrics.metric import Metric, ElementwiseMetric, MultiTaskMetric
from wilds.common.metrics.loss import ElementwiseLoss
from wilds.common.utils import avg_over_groups, minimum, maximum, get_counts
@@ -243,12 +243,17 @@ def _accuracy(self, src_boxes,pred_boxes , iou_threshold):
total_pred = len(pred_boxes)
if total_gt > 0 and total_pred > 0:
# Define the matcher and distance matrix based on iou
matcher = Matcher(iou_threshold,iou_threshold,allow_low_quality_matches=False)
match_quality_matrix = box_iou(src_boxes,pred_boxes)
matcher = Matcher(
iou_threshold,
iou_threshold,
allow_low_quality_matches=False)
match_quality_matrix = box_iou(
src_boxes,
pred_boxes)
results = matcher(match_quality_matrix)
true_positive = torch.count_nonzero(results.unique() != -1)
matched_elements = results[results > -1]
#in Matcher, a pred element can be matched only twice
# in Matcher, a pred element can be matched only twice
false_positive = (
torch.count_nonzero(results == -1) +
(len(matched_elements) - len(matched_elements.unique()))
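For reference, a small standalone sketch of the `box_iou` + `Matcher` combination used in `_accuracy` above; the box coordinates are made up and given in `(x1, y1, x2, y2)` format:

```python
import torch
from torchvision.ops.boxes import box_iou
from torchvision.models.detection._utils import Matcher

src_boxes = torch.tensor([[0., 0., 10., 10.],     # ground-truth boxes
                          [20., 20., 30., 30.]])
pred_boxes = torch.tensor([[1., 1., 10., 10.],    # predicted boxes
                           [50., 50., 60., 60.]])

matcher = Matcher(0.5, 0.5, allow_low_quality_matches=False)
match_quality_matrix = box_iou(src_boxes, pred_boxes)   # shape (num_gt, num_pred)
results = matcher(match_quality_matrix)                 # per-prediction gt index, or -1 if unmatched
print(results)  # tensor([0, -1]): the first prediction matches gt 0, the second is unmatched
```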
19 changes: 14 additions & 5 deletions wilds/datasets/wilds_dataset.py
@@ -433,11 +433,16 @@ def standard_group_eval(metric, grouper, y_pred, y_true, metadata, aggregate=Tru


class WILDSSubset(WILDSDataset):
def __init__(self, dataset, indices, transform):
def __init__(self, dataset, indices, transform, do_transform_y=False):
"""
This acts like torch.utils.data.Subset, but on WILDSDatasets.
We pass in transform explicitly because it can potentially vary at
training vs. test time, if we're using data augmentation.
This acts like `torch.utils.data.Subset`, but on `WILDSDatasets`.
We pass in `transform` (which is used for data augmentation) explicitly
because it can potentially vary on the training vs. test subsets.

`do_transform_y` (bool): When this is false (the default),
`self.transform` acts only on `x`.
Set this to true if `self.transform` should
operate on `(x,y)` instead of just `x`.
"""
self.dataset = dataset
self.indices = indices
@@ -449,11 +454,15 @@ def __init__(self, dataset, indices, transform):
if hasattr(dataset, attr_name):
setattr(self, attr_name, getattr(dataset, attr_name))
self.transform = transform
self.do_transform_y = do_transform_y

def __getitem__(self, idx):
x, y, metadata = self.dataset[self.indices[idx]]
if self.transform is not None:
x, y = self.transform(x, y)
if self.do_transform_y:
x, y = self.transform(x, y)
else:
x = self.transform(x)
return x, y, metadata

def __len__(self):
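A hedged usage sketch of the new flag; the `full_dataset`, `train_indices`, `train_transform`, and `joint_flip_transform` names below are hypothetical placeholders, not part of WILDS:

```python
import torch

# Hypothetical joint transform: acts on both the input and the label.
# Only an illustration of do_transform_y; WILDS does not ship this transform.
def joint_flip_transform(x, y):
    x = torch.flip(x, dims=[-1])   # horizontally flip the image tensor
    # ...remap y here if the labels are spatial (e.g. bounding boxes)...
    return x, y

# Default behaviour: the transform receives and returns only x.
train_data = WILDSSubset(full_dataset, train_indices, transform=train_transform)

# Opt-in behaviour: the transform receives and returns (x, y).
train_data_xy = WILDSSubset(
    full_dataset,
    train_indices,
    transform=joint_flip_transform,
    do_transform_y=True)
```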
2 changes: 1 addition & 1 deletion wilds/version.py
@@ -4,7 +4,7 @@
import logging
from threading import Thread

__version__ = '1.2.1'
__version__ = '1.2.2'

try:
os.environ['OUTDATED_IGNORE'] = '1'