89 changes: 89 additions & 0 deletions references/video_classification/sampler.py
@@ -0,0 +1,89 @@
import math
import torch
from torch.utils.data import Sampler
import torch.distributed as dist
import torchvision.datasets.video_utils


class DistributedSampler(Sampler):
"""
Extension of DistributedSampler, as discussed in
https://github.com/pytorch/pytorch/issues/23430
"""

def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
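        # every rank draws the same number of samples; __iter__ pads the
        # index list so it divides evenly across replicas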
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle

def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))

# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size

# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples

        if isinstance(self.dataset, Sampler):
            # when wrapping another sampler, translate the subsampled
            # positions back into that sampler's own index order
            orig_indices = list(iter(self.dataset))
            indices = [orig_indices[i] for i in indices]

return iter(indices)

def __len__(self):
return self.num_samples

def set_epoch(self, epoch):
self.epoch = epoch


class UniformClipSampler(torch.utils.data.Sampler):
"""
Samples at most `max_video_clips_per_video` clips for each video, equally spaced
Arguments:
video_clips (VideoClips): video clips to sample from
max_clips_per_video (int): maximum number of clips to be sampled per video
"""
def __init__(self, video_clips, max_clips_per_video):
if not isinstance(video_clips, torchvision.datasets.video_utils.VideoClips):
raise TypeError("Expected video_clips to be an instance of VideoClips, "
"got {}".format(type(video_clips)))
self.video_clips = video_clips
self.max_clips_per_video = max_clips_per_video

def __iter__(self):
idxs = []
s = 0
# select at most max_clips_per_video for each video, uniformly spaced
for c in self.video_clips.clips:
length = len(c)
            num_clips = min(length, self.max_clips_per_video)
            # pick exactly num_clips evenly spaced indices into this video's
            # clips, keeping __iter__ consistent with __len__ below (a strided
            # slice can yield one clip too many, e.g. length=10, max=3 -> 4)
            sampled = torch.linspace(s, s + length - 1, steps=num_clips).floor().to(torch.int64)
s += length
idxs.append(sampled)
idxs = torch.cat(idxs).tolist()
return iter(idxs)

def __len__(self):
return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
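The `isinstance(self.dataset, Sampler)` branch in `DistributedSampler.__iter__` is the extension discussed in pytorch/pytorch#23430: the class can wrap another sampler rather than a dataset, so each rank receives a disjoint slice of that sampler's index order. A minimal sketch of how the two samplers compose, assuming a dataset that exposes a `VideoClips` instance (names such as `dataset` and `num_epochs` are placeholders, not part of this file):

import torch
from torch.utils.data import DataLoader

# dataset is assumed to expose a VideoClips instance, e.g. dataset.video_clips
clip_sampler = UniformClipSampler(dataset.video_clips, max_clips_per_video=5)
train_sampler = DistributedSampler(clip_sampler, num_replicas=2, rank=0, shuffle=True)
data_loader = DataLoader(dataset, batch_size=8, sampler=train_sampler)

for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)  # reseed the deterministic shuffle
    for video, target in data_loader:
        pass  # training step goes here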
47 changes: 47 additions & 0 deletions references/video_classification/scheduler.py
@@ -0,0 +1,47 @@
import torch
from bisect import bisect_right


class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(
self,
optimizer,
milestones,
gamma=0.1,
warmup_factor=1.0 / 3,
warmup_iters=5,
warmup_method="linear",
last_epoch=-1,
):
        if list(milestones) != sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. "
                "Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
self.milestones = milestones
self.gamma = gamma
self.warmup_factor = warmup_factor
self.warmup_iters = warmup_iters
self.warmup_method = warmup_method
super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

def get_lr(self):
warmup_factor = 1
if self.last_epoch < self.warmup_iters:
if self.warmup_method == "constant":
warmup_factor = self.warmup_factor
elif self.warmup_method == "linear":
alpha = float(self.last_epoch) / self.warmup_iters
warmup_factor = self.warmup_factor * (1 - alpha) + alpha
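        # final lr = base_lr * warmup_factor * gamma ** (milestones passed)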
return [
base_lr *
warmup_factor *
self.gamma ** bisect_right(self.milestones, self.last_epoch)
for base_lr in self.base_lrs
]
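For concreteness, with `lr=0.01`, `warmup_iters=5`, linear warmup from `warmup_factor=1/3`, and `milestones=[30, 60]`: the learning rate starts near 0.0033, climbs linearly to 0.01 by epoch 5, then decays by `gamma=0.1` to 0.001 at epoch 30 and 0.0001 at epoch 60. A short usage sketch under those assumptions (the model is a stand-in):

import torch

model = torch.nn.Linear(10, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
lr_scheduler = WarmupMultiStepLR(
    optimizer, milestones=[30, 60], gamma=0.1,
    warmup_factor=1.0 / 3, warmup_iters=5, warmup_method="linear")

for epoch in range(70):
    # ... one epoch of training (optimizer.step() per batch) ...
    lr_scheduler.step()
    print(epoch, optimizer.param_groups[0]["lr"])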