From faf7065a8b98f6f07fe8ee6cc9e017ada75fc51a Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 27 Jul 2021 14:25:43 -0700 Subject: [PATCH 1/7] Default to transforming only x. --- examples/run_expt.py | 5 +++-- examples/transforms.py | 20 ++++++++------------ wilds/datasets/wilds_dataset.py | 19 ++++++++++++++----- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/examples/run_expt.py b/examples/run_expt.py index 4e20afdb..c533c35c 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -155,8 +155,9 @@ def main(): split_scheme=config.split_scheme, **config.dataset_kwargs) - # To implement data augmentation (i.e., have different transforms - # at training time vs. test time), modify these two lines: + # To modify data augmentation, modify the following code block. + # If you want to use transforms that modify both `x` and `y`, + # set `do_transform_y` to True when initializing the `WILDSSubset` below. train_transform = initialize_transform( transform_name=config.transform, config=config, diff --git a/examples/transforms.py b/examples/transforms.py index b1c9b97a..a997110d 100644 --- a/examples/transforms.py +++ b/examples/transforms.py @@ -7,8 +7,9 @@ def initialize_transform(transform_name, config, dataset, is_training): """ - Transforms should take in a single (x, y) - and return (transformed_x, transformed_y). + By default, transforms should take in `x` and return `transformed_x`. + For transforms that take in `(x, y)` and return `(transformed_x, transformed_y)`, + set `do_transform_y` to True when initializing the WILDSSubset. 
""" if transform_name is None: return None @@ -25,11 +26,6 @@ def initialize_transform(transform_name, config, dataset, is_training): else: raise ValueError(f"{transform_name} not recognized") -def transform_input_only(input_transform): - def transform(x, y): - return input_transform(x), y - return transform - def initialize_bert_transform(config): assert 'bert' in config.model assert config.max_token_length is not None @@ -55,7 +51,7 @@ def transform(text): dim=2) x = torch.squeeze(x, dim=0) # First shape dim is always 1 return x - return transform_input_only(transform) + return transform def getBertTokenizer(model): if model == 'bert-base-uncased': @@ -79,7 +75,7 @@ def initialize_image_base_transform(config, dataset): transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ] transform = transforms.Compose(transform_steps) - return transform_input_only(transform) + return transform def initialize_image_resize_and_center_crop_transform(config, dataset): """ @@ -98,7 +94,7 @@ def initialize_image_resize_and_center_crop_transform(config, dataset): transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) - return transform_input_only(transform) + return transform def initialize_poverty_transform(is_training): if is_training: @@ -115,7 +111,7 @@ def transform_rgb(img): img[:3] = rgb_transform(img[:3][[2,1,0]])[[2,1,0]] return img transform = transforms.Lambda(lambda x: transform_rgb(x)) - return transform_input_only(transform) + return transform else: return None @@ -148,4 +144,4 @@ def random_rotation(x: torch.Tensor) -> torch.Tensor: t_standardize, ] transform = transforms.Compose(transforms_ls) - return transform_input_only(transform) + return transform diff --git a/wilds/datasets/wilds_dataset.py b/wilds/datasets/wilds_dataset.py index 48add021..094c1a28 100644 --- a/wilds/datasets/wilds_dataset.py +++ b/wilds/datasets/wilds_dataset.py @@ -433,11 +433,16 @@ def standard_group_eval(metric, grouper, y_pred, y_true, 
metadata, aggregate=Tru class WILDSSubset(WILDSDataset): - def __init__(self, dataset, indices, transform): + def __init__(self, dataset, indices, transform, do_transform_y=False): """ - This acts like torch.utils.data.Subset, but on WILDSDatasets. - We pass in transform explicitly because it can potentially vary at - training vs. test time, if we're using data augmentation. + This acts like `torch.utils.data.Subset`, but on `WILDSDatasets`. + We pass in `transform` (which is used for data augmentation) explicitly + because it can potentially vary on the training vs. test subsets. + + `do_transform_y` (bool): When this is false (the default), + `self.transform` acts only on `x`. + Set this to true if `self.transform` should + operate on `(x, y)` instead of just `x`. """ self.dataset = dataset self.indices = indices @@ -449,11 +454,15 @@ def __init__(self, dataset, indices, transform): if hasattr(dataset, attr_name): setattr(self, attr_name, getattr(dataset, attr_name)) self.transform = transform + self.do_transform_y = do_transform_y def __getitem__(self, idx): x, y, metadata = self.dataset[self.indices[idx]] if self.transform is not None: - x, y = self.transform(x, y) + if self.do_transform_y: + x, y = self.transform(x, y) + else: + x = self.transform(x) return x, y, metadata def __len__(self): From e31ad802f473d1ed4288c2c5e3e5396088c3c3fe Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sat, 31 Jul 2021 13:39:20 -0700 Subject: [PATCH 2/7] Add dependencies --- README.md | 6 ++++-- setup.py | 10 ++++++---- wilds/common/metrics/all_metrics.py | 15 ++++++++++----- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index c16b6eab..c0a04501 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,10 @@ pip install -e . 
- torch>=1.7.0 - torch-scatter>=2.0.5 - torch-geometric>=1.6.1 +- torchvision>=0.8.2 - tqdm>=4.53.0 +- scikit-learn>=0.20.0 +- scipy>=1.5.4 Running `pip install wilds` or `pip install -e .` will automatically check for and install all of these requirements except for the `torch-scatter` and `torch-geometric` packages, which require a [quick manual install](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html#installation-via-binaries). @@ -63,9 +66,8 @@ These scripts are not part of the installed WILDS package. To use them, you shou git clone git@github.com:p-lambda/wilds.git ``` -To run these scripts, you will need to install these additional dependencies: +To run these scripts, you will also need to install this additional dependency: -- torchvision>=0.8.2 - transformers>=3.5.0 All baseline experiments in the paper were run on Python 3.8.5 and CUDA 10.1. diff --git a/setup.py b/setup.py index 9cd1f596..53003e7a 100644 --- a/setup.py +++ b/setup.py @@ -22,14 +22,16 @@ long_description_content_type="text/markdown", install_requires = [ 'numpy>=1.19.1', + 'ogb>=1.2.6', + 'outdated>=0.2.0', 'pandas>=1.1.0', - 'scikit-learn>=0.20.0', 'pillow>=7.2.0', + 'pytz>=2020.4', 'torch>=1.7.0', - 'ogb>=1.2.6', + 'torchvision>=0.8.2', 'tqdm>=4.53.0', - 'outdated>=0.2.0', - 'pytz>=2020.4', + 'scikit-learn>=0.20.0', + 'scipy>=1.5.4' ], license='MIT', packages=setuptools.find_packages(exclude=['dataset_preprocessing', 'examples', 'examples.models', 'examples.models.bert']), diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index cdc57b27..5c93eda7 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -1,10 +1,10 @@ +import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from torchvision.ops.boxes import box_iou from torchvision.models.detection._utils import Matcher from torchvision.ops import nms, box_convert -import numpy as np -import 
torch.nn.functional as F from wilds.common.metrics.metric import Metric, ElementwiseMetric, MultiTaskMetric from wilds.common.metrics.loss import ElementwiseLoss from wilds.common.utils import avg_over_groups, minimum, maximum, get_counts @@ -243,12 +243,17 @@ def _accuracy(self, src_boxes,pred_boxes , iou_threshold): total_pred = len(pred_boxes) if total_gt > 0 and total_pred > 0: # Define the matcher and distance matrix based on iou - matcher = Matcher(iou_threshold,iou_threshold,allow_low_quality_matches=False) - match_quality_matrix = box_iou(src_boxes,pred_boxes) + matcher = Matcher( + iou_threshold, + iou_threshold, + allow_low_quality_matches=False) + match_quality_matrix = box_iou( + src_boxes, + pred_boxes) results = matcher(match_quality_matrix) true_positive = torch.count_nonzero(results.unique() != -1) matched_elements = results[results > -1] - #in Matcher, a pred element can be matched only twice + # in Matcher, a pred element can be matched only twice false_positive = ( torch.count_nonzero(results == -1) + (len(matched_elements) - len(matched_elements.unique())) From 061b3c47176d9a2be08ad45e6d2a1802f353f746 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sat, 31 Jul 2021 13:55:50 -0700 Subject: [PATCH 3/7] Only allow n_groups_per_batch and distinct_groups to be set if using a group loader --- wilds/common/data_loaders.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/wilds/common/data_loaders.py b/wilds/common/data_loaders.py index d051878e..0fe44fac 100644 --- a/wilds/common/data_loaders.py +++ b/wilds/common/data_loaders.py @@ -27,6 +27,10 @@ def get_train_loader(loader, dataset, batch_size, - data loader (DataLoader): Data loader. """ if loader == 'standard': + if n_groups_per_batch is not None: + raise ValueError("n_groups_per_batch cannot be specified if the data loader is 'standard'. 
Consider using a 'group' data loader instead.") + if distinct_groups is not None: + raise ValueError("distinct_groups cannot be specified if the data loader is 'standard'. Consider using a 'group' data loader instead.") if uniform_over_groups is None or not uniform_over_groups: return DataLoader( dataset, From 20b5ee953cf91bf888872f0f4d19bb7302e23607 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 1 Aug 2021 00:22:05 -0700 Subject: [PATCH 4/7] Move data loader validation to run_expt args --- examples/configs/utils.py | 36 +++++++++++++++++++++++++----------- wilds/common/data_loaders.py | 4 ---- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/examples/configs/utils.py b/examples/configs/utils.py index b574bf5e..020bb43f 100644 --- a/examples/configs/utils.py +++ b/examples/configs/utils.py @@ -1,3 +1,4 @@ +import copy from configs.algorithm import algorithm_defaults from configs.model import model_defaults from configs.scheduler import scheduler_defaults @@ -7,41 +8,44 @@ def populate_defaults(config): """Populates hyperparameters with defaults implied by choices of other hyperparameters.""" + + orig_config = copy.deepcopy(config) assert config.dataset is not None, 'dataset must be specified' assert config.algorithm is not None, 'algorithm must be specified' + # implied defaults from choice of dataset config = populate_config( - config, + config, dataset_defaults[config.dataset] ) # implied defaults from choice of split if config.dataset in split_defaults and config.split_scheme in split_defaults[config.dataset]: config = populate_config( - config, + config, split_defaults[config.dataset][config.split_scheme] ) - + # implied defaults from choice of algorithm config = populate_config( - config, + config, algorithm_defaults[config.algorithm] ) # implied defaults from choice of loader config = populate_config( - config, + config, loader_defaults ) # implied defaults from choice of model if config.model: config = populate_config( - config, + 
config, model_defaults[config.model], ) - + # implied defaults from choice of scheduler if config.scheduler: config = populate_config( - config, + config, scheduler_defaults[config.scheduler] ) @@ -52,12 +56,22 @@ def populate_defaults(config): # basic checks required_fields = [ - 'split_scheme', 'train_loader', 'uniform_over_groups', 'batch_size', 'eval_loader', 'model', 'loss_function', + 'split_scheme', 'train_loader', 'uniform_over_groups', 'batch_size', 'eval_loader', 'model', 'loss_function', 'val_metric', 'val_metric_decreasing', 'n_epochs', 'optimizer', 'lr', 'weight_decay', - ] + ] for field in required_fields: assert getattr(config, field) is not None, f"Must manually specify {field} for this setup." + # data loader validations + # we only raise this error if the train_loader is standard, and + # n_groups_per_batch or distinct_groups are + # specified by the user (instead of populated as a default) + if config.train_loader == 'standard': + if orig_config.n_groups_per_batch is not None: + raise ValueError("n_groups_per_batch cannot be specified if the data loader is 'standard'. Consider using a 'group' data loader instead.") + if orig_config.distinct_groups is not None: + raise ValueError("distinct_groups cannot be specified if the data loader is 'standard'. 
Consider using a 'group' data loader instead.") + return config def populate_config(config, template: dict, force_compatibility=False): @@ -78,7 +92,7 @@ def populate_config(config, template: dict, force_compatibility=False): d_config[key] = val elif d_config[key] != val and force_compatibility: raise ValueError(f"Argument {key} must be set to {val}") - + else: # config[key] expected to be a kwarg dict for kwargs_key, kwargs_val in val.items(): if kwargs_key not in d_config[key] or d_config[key][kwargs_key] is None: diff --git a/wilds/common/data_loaders.py b/wilds/common/data_loaders.py index 0fe44fac..d051878e 100644 --- a/wilds/common/data_loaders.py +++ b/wilds/common/data_loaders.py @@ -27,10 +27,6 @@ def get_train_loader(loader, dataset, batch_size, - data loader (DataLoader): Data loader. """ if loader == 'standard': - if n_groups_per_batch is not None: - raise ValueError("n_groups_per_batch cannot be specified if the data loader is 'standard'. Consider using a 'group' data loader instead.") - if distinct_groups is not None: - raise ValueError("distinct_groups cannot be specified if the data loader is 'standard'. Consider using a 'group' data loader instead.") if uniform_over_groups is None or not uniform_over_groups: return DataLoader( dataset, From f86f1cd5686757f6baf9803a7f16472277b09143 Mon Sep 17 00:00:00 2001 From: Shiori Sagawa Date: Wed, 4 Aug 2021 11:22:27 -0700 Subject: [PATCH 5/7] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c0a04501..efb0faa0 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ pip install wilds If you have already installed it, please check that you have the latest version: ```bash python -c "import wilds; print(wilds.__version__)" -# This should print "1.2.1". If it doesn't, update by running: +# This should print "1.2.2". 
If it doesn't, update by running: pip install -U wilds ``` From 9470cb63231effb4581e76b95ce94c6edc4acf37 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Wed, 4 Aug 2021 11:44:07 -0700 Subject: [PATCH 6/7] fix for torchvision v0.10 compatibility --- examples/models/detection/fasterrcnn.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/models/detection/fasterrcnn.py b/examples/models/detection/fasterrcnn.py index 59d736a1..1eb4cd4a 100644 --- a/examples/models/detection/fasterrcnn.py +++ b/examples/models/detection/fasterrcnn.py @@ -27,7 +27,6 @@ from torchvision.models.utils import load_state_dict_from_url from torchvision.ops import misc as misc_nn_ops from torchvision.ops import MultiScaleRoIAlign -from torchvision.models.detection import _utils as det_utils from torchvision.models.detection.anchor_utils import AnchorGenerator from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN from torchvision.models.detection.faster_rcnn import TwoMLPHead @@ -127,11 +126,11 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets) sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0] sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0) - box_loss.append(det_utils.smooth_l1_loss( + box_loss.append(F.smooth_l1_loss( pred_bbox_deltas_[sampled_pos_inds], regression_targets_[sampled_pos_inds], beta=1 / 9, - size_average=False, + reduction='sum', ) / (sampled_inds.numel())) objectness_loss.append(F.binary_cross_entropy_with_logits( @@ -226,11 +225,11 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): box_regression_ = box_regression_.reshape(N, -1, 4) - box_loss_ = det_utils.smooth_l1_loss( + box_loss_ = F.smooth_l1_loss( box_regression_[sampled_pos_inds_subset, labels_pos], regression_targets_[sampled_pos_inds_subset], beta=1 / 9, - size_average=False, + reduction='sum', ) box_loss.append(box_loss_ / labels_.numel()) From 
88ba842b80075a0cef89e43cf066a1912c90bbb8 Mon Sep 17 00:00:00 2001 From: Shiori Sagawa Date: Wed, 4 Aug 2021 15:30:41 -0700 Subject: [PATCH 7/7] bump up version --- wilds/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wilds/version.py b/wilds/version.py index 0d065bf6..531433b0 100644 --- a/wilds/version.py +++ b/wilds/version.py @@ -4,7 +4,7 @@ import logging from threading import Thread -__version__ = '1.2.1' +__version__ = '1.2.2' try: os.environ['OUTDATED_IGNORE'] = '1'