We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Which approach will we use to support a distributed sampler?
From DINO (torch.utils.data):
# DINO training data pipeline (excerpt). `args`, `DataAugmentationDINO` and
# `datasets` are defined outside this excerpt.
# Multi-crop augmentation configured from the global/local crop settings.
transform = DataAugmentationDINO(
    args.global_crops_scale,
    args.local_crops_scale,
    args.local_crops_number,
)
# Image-folder dataset rooted at `args.data_path`, with `transform` attached.
dataset = datasets.ImageFolder(args.data_path, transform=transform)
# Shuffling is delegated to the DistributedSampler (shuffle=True); the
# DataLoader below receives it through `sampler=` and is not itself given
# shuffle=True.
# NOTE(review): PyTorch's DistributedSampler is normally paired with a call
# to `sampler.set_epoch(epoch)` at the start of every epoch so the shuffle
# order changes between epochs; that call is not part of this excerpt --
# confirm the training loop makes it.
sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
data_loader = torch.utils.data.DataLoader(
    dataset,
    sampler=sampler,
    batch_size=args.batch_size_per_gpu,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=True,
)
From BYOL (tfds): https://github.com/deepmind/deepmind-research/blob/2c7c401024c42c4fb1aa20a8b0471d2e6b480906/byol/utils/dataset.py#L56
class PreprocessMode(enum.Enum):
    """Preprocessing modes for the dataset."""
    PRETRAIN = 1  # Generates two augmented views (random crop + augmentations).
    LINEAR_TRAIN = 2  # Generates a single random crop.
    EVAL = 3  # Generates a single center crop.


def normalize_images(images: jnp.ndarray) -> jnp.ndarray:
    """Normalize the image using ImageNet statistics.

    Subtracts the per-channel RGB mean and divides by the per-channel
    standard deviation. Both constant triples are reshaped to (1, 1, 1, 3),
    so `images` must broadcast against that shape, i.e. its last axis holds
    the 3 RGB channels (presumably an NHWC batch -- confirm against callers).

    Args:
      images: Image array with RGB channels on the last axis.

    Returns:
      `(images - mean_rgb) / stddev_rgb`, computed channel-wise.
    """
    mean_rgb = (0.485, 0.456, 0.406)
    stddev_rgb = (0.229, 0.224, 0.225)
    normed_images = images - jnp.array(mean_rgb).reshape((1, 1, 1, 3))
    normed_images = normed_images / jnp.array(stddev_rgb).reshape((1, 1, 1, 3))
    return normed_images


def load(split: Split,
         *,
         preprocess_mode: PreprocessMode,
         batch_dims: Sequence[int],
         transpose: bool = False,
         allow_caching: bool = False) -> Generator[Batch, None, None]:
    """Loads the given split of the dataset.

    Reads the example range that `_shard` computes for this host from the
    'imagenet2012:5.*.*' TFDS dataset. Each example is preprocessed according
    to `preprocess_mode`, the result is batched, and it is yielded as NumPy
    data via `tfds.as_numpy`.

    Because the body uses `yield from`, this is a generator function. None of
    the work below, including the ValueError check, runs until the caller
    starts iterating.

    Args:
      split: The split to load. `_shard` maps it, together with
        `jax.host_id()` and `jax.host_count()`, to the `start`/`end` example
        range read here.
      preprocess_mode: PRETRAIN yields 'view1'/'view2'/'labels'. LINEAR_TRAIN
        and EVAL yield 'images'/'labels'. Every mode except EVAL also sets
        `options.experimental_deterministic = False` and repeats and
        shuffles the data.
      batch_dims: Batch sizes, applied innermost-first (`reversed`) through
        successive `ds.batch` calls. Their product is the total batch size,
        which sizes the shuffle buffer and drives the EVAL divisibility check.
      transpose: If True, transpose the innermost batch with permutation
        (1, 2, 3, 0), i.e. NHWC -> HWCN as noted at the `ds.map` call.
      allow_caching: Cache the dataset before repeating. Only honored in
        non-EVAL modes, and only when `jax.host_count() > 1`.

    Yields:
      Batches as produced by `tfds.as_numpy`, keyed as described for
      `preprocess_mode`. Labels are cast to `tf.int32` before batching.

    Raises:
      ValueError: In EVAL mode, if `split.num_examples` is not divisible by
        the total batch size. Raised on first iteration.
    """
    # NOTE(review): newer JAX releases name these jax.process_index() /
    # jax.process_count(); confirm the host_* aliases still exist in the JAX
    # version in use.
    start, end = _shard(split, jax.host_id(), jax.host_count())
    # Product of all batch dims: sizes the shuffle buffer and the EVAL check.
    total_batch_size = np.prod(batch_dims)
    tfds_split = tfds.core.ReadInstruction(
        _to_tfds_split(split), from_=start, to=end, unit='abs')
    # Images are loaded with SkipDecoding(); presumably `_preprocess_image`
    # does the decoding -- confirm.
    ds = tfds.load(
        'imagenet2012:5.*.*',
        split=tfds_split,
        decoders={'image': tfds.decode.SkipDecoding()})

    options = tf.data.Options()
    # NOTE(review): newer TensorFlow versions expose these settings as
    # `options.threading.*` and `options.deterministic`; confirm against the
    # installed TF version.
    options.experimental_threading.private_threadpool_size = 48
    options.experimental_threading.max_intra_op_parallelism = 1
    if preprocess_mode is not PreprocessMode.EVAL:
        # Training modes: ordering need not be deterministic. The data is
        # optionally cached, then repeated and shuffled with a fixed seed.
        options.experimental_deterministic = False
        if jax.host_count() > 1 and allow_caching:
            # Only cache if we are reading a subset of the dataset.
            ds = ds.cache()
        ds = ds.repeat()
        ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)
    else:
        # EVAL: no repeat/shuffle, so the split must fill whole batches.
        if split.num_examples % total_batch_size != 0:
            raise ValueError(
                f'Test/valid must be divisible by {total_batch_size}')

    ds = ds.with_options(options)

    def preprocess_pretrain(example):
        # Two separate `_preprocess_image` calls on the same source image.
        view1 = _preprocess_image(example['image'], mode=preprocess_mode)
        view2 = _preprocess_image(example['image'], mode=preprocess_mode)
        label = tf.cast(example['label'], tf.int32)
        return {'view1': view1, 'view2': view2, 'labels': label}

    def preprocess_linear_train(example):
        image = _preprocess_image(example['image'], mode=preprocess_mode)
        label = tf.cast(example['label'], tf.int32)
        return {'images': image, 'labels': label}

    # NOTE(review): same body as preprocess_linear_train. LINEAR_TRAIN and
    # EVAL differ only through the `mode=preprocess_mode` value passed to
    # `_preprocess_image`.
    def preprocess_eval(example):
        image = _preprocess_image(example['image'], mode=preprocess_mode)
        label = tf.cast(example['label'], tf.int32)
        return {'images': image, 'labels': label}

    # Select the per-example mapping for the requested mode.
    if preprocess_mode is PreprocessMode.PRETRAIN:
        ds = ds.map(
            preprocess_pretrain,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)
    elif preprocess_mode is PreprocessMode.LINEAR_TRAIN:
        ds = ds.map(
            preprocess_linear_train,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)
    else:
        ds = ds.map(
            preprocess_eval,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)

    def transpose_fn(batch):
        # We use the double-transpose-trick to improve performance for TPUs. Note
        # that this (typically) requires a matching HWCN->NHWC transpose in your
        # model code. The compiler cannot make this optimization for us since our
        # data pipeline and model are compiled separately.
        # Work on a shallow copy rather than mutating `batch` in place.
        batch = dict(**batch)
        if preprocess_mode is PreprocessMode.PRETRAIN:
            batch['view1'] = tf.transpose(batch['view1'], (1, 2, 3, 0))
            batch['view2'] = tf.transpose(batch['view2'], (1, 2, 3, 0))
        else:
            batch['images'] = tf.transpose(batch['images'], (1, 2, 3, 0))
        return batch

    # Batch from the innermost dimension outwards. The optional transpose is
    # applied right after the first (innermost) batching step only.
    for i, batch_size in enumerate(reversed(batch_dims)):
        ds = ds.batch(batch_size)
        if i == 0 and transpose:
            ds = ds.map(transpose_fn)  # NHWC -> HWCN

    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

    yield from tfds.as_numpy(ds)
The text was updated successfully, but these errors were encountered:
No branches or pull requests
Which approach will we use to support a distributed sampler?
From DINO (torch.utils.data):
From BYOL (tfds): https://github.com/deepmind/deepmind-research/blob/2c7c401024c42c4fb1aa20a8b0471d2e6b480906/byol/utils/dataset.py#L56
The text was updated successfully, but these errors were encountered: