-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
820a9e7
commit bddb58d
Showing
100 changed files
with
6,185 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from __future__ import absolute_import | ||
from .sampler import * | ||
from .datasequence import Datasequence | ||
from .seqpreprocessor import SeqTrainPreprocessor | ||
from .seqpreprocessor import SeqTestPreprocessor | ||
from .dataloader import get_data | ||
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from __future__ import print_function | ||
import os.path as osp | ||
from torch.utils.data import DataLoader | ||
from reid.dataset import get_sequence | ||
from reid.data import seqtransforms as T | ||
from reid.data import SeqTrainPreprocessor | ||
from reid.data import SeqTestPreprocessor | ||
from reid.data import RandomPairSampler | ||
|
||
|
||
def get_data(dataset_name, split_id, data_dir, batch_size, seq_len, seq_srd,
             workers, train_mode, test_batch_size=8):
    """Build the sequence dataset and its train/query/gallery loaders.

    Args:
        dataset_name: name of the sequence dataset to load.
        split_id: index of the train/test split to use.
        data_dir: root directory containing all datasets.
        batch_size: batch size of the training loader.
        seq_len: number of frames per sequence sample.
        seq_srd: stride between consecutive sequence windows.
        workers: number of DataLoader worker processes.
        train_mode: 'cnn_rnn' (camera-pair sampling) or 'cnn' (shuffled).
        test_batch_size: batch size of the query/gallery loaders
            (default 8, matching the previously hard-coded value).

    Returns:
        (dataset, num_classes, train_loader, query_loader, gallery_loader)

    Raises:
        ValueError: if ``train_mode`` is neither 'cnn_rnn' nor 'cnn'.
    """
    root = osp.join(data_dir, dataset_name)
    dataset = get_sequence(dataset_name, root, split_id=split_id,
                           seq_len=seq_len, seq_srd=seq_srd, num_val=1,
                           download=True)
    train_set = dataset.trainval
    num_classes = dataset.num_trainval_ids

    # ImageNet statistics; presumably inputs are RGB in [0, 1] after
    # ToTensor — TODO confirm against seqtransforms.
    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])

    train_processor = SeqTrainPreprocessor(
        train_set, dataset, seq_len,
        transform=T.Compose([T.RectScale(256, 128),
                             T.RandomHorizontalFlip(),
                             T.RandomSizedEarser(),
                             T.ToTensor(), normalizer]))

    # Query and gallery share the same deterministic test-time transform.
    test_transform = T.Compose([T.RectScale(256, 128),
                                T.ToTensor(), normalizer])
    query_processor = SeqTestPreprocessor(dataset.query, dataset, seq_len,
                                          transform=test_transform)
    gallery_processor = SeqTestPreprocessor(dataset.gallery, dataset, seq_len,
                                            transform=test_transform)

    if train_mode == 'cnn_rnn':
        # Pair sampling: for every sample, also draw a positive of the
        # same identity (see RandomPairSampler).
        train_loader = DataLoader(train_processor, batch_size=batch_size,
                                  num_workers=workers,
                                  sampler=RandomPairSampler(train_set),
                                  pin_memory=True)
    elif train_mode == 'cnn':
        train_loader = DataLoader(train_processor, batch_size=batch_size,
                                  num_workers=workers, shuffle=True,
                                  pin_memory=True)
    else:
        raise ValueError('no such train mode')

    query_loader = DataLoader(query_processor, batch_size=test_batch_size,
                              num_workers=workers, shuffle=False,
                              pin_memory=True)

    gallery_loader = DataLoader(gallery_processor, batch_size=test_batch_size,
                                num_workers=workers, shuffle=False,
                                pin_memory=True)

    return dataset, num_classes, train_loader, query_loader, gallery_loader
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from __future__ import print_function | ||
import os.path as osp | ||
import numpy as np | ||
from utils.serialization import read_json | ||
|
||
|
||
def _pluckseq(identities, indices, seq_len, seq_str): | ||
ret = [] | ||
for index, pid in enumerate(indices): | ||
pid_images = identities[pid] | ||
for camid, cam_images in enumerate(pid_images): | ||
seqall = len(cam_images) | ||
seq_inds = [(start_ind, start_ind + seq_len)\ | ||
for start_ind in range(0, seqall-seq_len, seq_str)] | ||
|
||
if not seq_inds: | ||
seq_inds = [(0, seqall)] | ||
for seq_ind in seq_inds: | ||
ret.append((seq_ind[0], seq_ind[1], pid, index, camid)) | ||
return ret | ||
|
||
|
||
|
||
class Datasequence(object):
    """Base class for on-disk video/sequence re-id datasets.

    Expects ``root`` to contain ``images/``, ``meta.json`` (with an
    'identities' list) and ``splits.json`` (a list of split dicts with
    'trainval', 'query' and 'gallery' keys).  ``load`` populates the
    train/val/trainval sample lists via ``_pluckseq``.
    """

    def __init__(self, root, split_id=0):
        self.root = root
        self.split_id = split_id
        self.meta = None
        self.split = None
        # Each entry is a (start, end, pid, label, camid) tuple.
        self.train, self.val, self.trainval = [], [], []
        self.query, self.gallery = [], []
        self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0
        self.identities = []

    @property
    def images_dir(self):
        # RGB frames live under <root>/images.
        return osp.join(self.root, 'images')

    def load(self, seq_len, seq_str, num_val=0.3, verbose=True):
        """Read the split/meta files and build the sample lists.

        Args:
            seq_len: frames per sequence window.
            seq_str: stride between window starts.
            num_val: number (int) or fraction (float) of trainval
                identities held out for validation.
            verbose: print a summary table after loading.

        Raises:
            ValueError: if ``split_id`` or ``num_val`` is out of range.

        Note: the train/val partition is re-shuffled on every call, so
        the validation set changes each time the dataset is loaded.
        """
        splits = read_json(osp.join(self.root, 'splits.json'))
        if self.split_id >= len(splits):
            raise ValueError("split_id exceeds total splits {}"
                             .format(len(splits)))

        self.split = splits[self.split_id]

        # Randomly split train / val identities.
        trainval_pids = np.asarray(self.split['trainval'])
        np.random.shuffle(trainval_pids)
        num = len(trainval_pids)

        if isinstance(num_val, float):
            num_val = int(round(num * num_val))
        if num_val >= num or num_val < 0:
            raise ValueError("num_val exceeds total identities {}"
                             .format(num))

        # Slice with explicit positions: the previous `[:-num_val]` form
        # silently produced an EMPTY training set when num_val == 0
        # (because `x[:-0]` is `x[:0]`).  For num_val > 0 this slicing
        # is identical to the old behavior.
        train_pids = sorted(trainval_pids[:num - num_val])
        val_pids = sorted(trainval_pids[num - num_val:])

        self.meta = read_json(osp.join(self.root, 'meta.json'))
        identities = self.meta['identities']
        self.identities = identities
        self.train = _pluckseq(identities, train_pids, seq_len, seq_str)
        self.val = _pluckseq(identities, val_pids, seq_len, seq_str)
        self.trainval = _pluckseq(identities, trainval_pids, seq_len, seq_str)
        self.num_train_ids = len(train_pids)
        self.num_val_ids = len(val_pids)
        self.num_trainval_ids = len(trainval_pids)

        if verbose:
            print(self.__class__.__name__, "dataset loaded")
            print(" subset | # ids | # sequences")
            print(" ---------------------------")
            print(" train | {:5d} | {:8d}"
                  .format(self.num_train_ids, len(self.train)))
            print(" val | {:5d} | {:8d}"
                  .format(self.num_val_ids, len(self.val)))
            print(" trainval | {:5d} | {:8d}"
                  .format(self.num_trainval_ids, len(self.trainval)))
            print(" query | {:5d} | {:8d}"
                  .format(len(self.split['query']), len(self.split['query'])))
            print(" gallery | {:5d} | {:8d}"
                  .format(len(self.split['gallery']), len(self.split['gallery'])))

    def _check_integrity(self):
        # All three on-disk artifacts must exist for the dataset to be usable.
        return osp.isdir(osp.join(self.root, 'images')) and \
            osp.isfile(osp.join(self.root, 'meta.json')) and \
            osp.isfile(osp.join(self.root, 'splits.json'))
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from __future__ import absolute_import | ||
from collections import defaultdict | ||
|
||
import numpy as np | ||
import torch | ||
from torch.utils.data.sampler import ( | ||
Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, | ||
WeightedRandomSampler) | ||
|
||
def No_index(a, b):
    """Return the positions in list *a* whose value differs from *b*."""
    assert isinstance(a, list)
    keep = []
    for pos, value in enumerate(a):
        if value != b:
            keep.append(pos)
    return keep
|
||
|
||
class RandomIdentitySampler(Sampler):
    """Sampler yielding ``num_instances`` sample indices per identity.

    Each pass permutes the identities and, for every identity, draws
    ``num_instances`` of its sample indices (with replacement when the
    identity has fewer samples than requested).
    """

    def __init__(self, data_source, num_instances=1):
        """
        Args:
            data_source: sequence of (fname, pid, camid) records.
            num_instances: samples drawn per identity per epoch.
        """
        self.data_source = data_source
        self.num_instances = num_instances
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())
        # One permutation slot per IDENTITY: __iter__ permutes
        # identities and indexes self.pids.  The previous
        # `len(data_source)` value made `self.pids[i]` raise IndexError
        # whenever several samples shared an identity (i.e. always, in
        # practice for re-id data).
        self.num_samples = len(self.pids)

    def __len__(self):
        return self.num_samples * self.num_instances

    def __iter__(self):
        indices = torch.randperm(self.num_samples)
        ret = []
        for i in indices:
            pid = self.pids[i]
            t = self.index_dic[pid]
            # Without replacement when the identity has enough samples,
            # otherwise pad by sampling with replacement.
            replace = len(t) < self.num_instances
            t = np.random.choice(t, size=self.num_instances, replace=replace)
            ret.extend(t)
        return iter(ret)
|
||
|
||
class RandomPairSampler(Sampler):
    """Sampler yielding (anchor, positive) pairs of sample indices.

    For every sample it emits the sample's own index followed by the
    index of another sample of the same identity, preferring one taken
    by a different camera.
    """

    def __init__(self, data_source):
        """
        Args:
            data_source: sequence of (start, end, label, pid, camid)
                records, as produced by ``_pluckseq``.
        """
        self.data_source = data_source
        self.index_pid = defaultdict(int)
        self.pid_cam = defaultdict(list)
        self.pid_index = defaultdict(list)
        self.num_samples = len(data_source)
        for index, (_, _, _, pid, cam) in enumerate(data_source):
            self.index_pid[index] = pid
            self.pid_cam[pid].append(cam)
            self.pid_index[pid].append(index)

    @staticmethod
    def _positions_not_equal(values, target):
        # Positions in `values` whose entry differs from `target`.
        return [pos for pos, v in enumerate(values) if v != target]

    def __len__(self):
        # One positive per anchor.
        return self.num_samples * 2

    def __iter__(self):
        indices = torch.randperm(self.num_samples)
        ret = []
        for i in indices:
            i = int(i)
            _, _, _, _, i_cam = self.data_source[i]
            ret.append(i)
            pid_i = self.index_pid[i]
            cams = self.pid_cam[pid_i]
            index = self.pid_index[pid_i]
            select_cams = self._positions_not_equal(cams, i_cam)
            if select_cams:
                # Preferred case: a positive from a different camera.
                select_ind = index[np.random.choice(select_cams)]
            else:
                # Identity seen by a single camera: fall back to any
                # other sample of the same identity, or repeat the
                # anchor if it is the identity's only sample.  (The
                # previous code printed diagnostics here and then
                # crashed with an unbound `select_camind` NameError.)
                others = self._positions_not_equal(index, i)
                select_ind = index[np.random.choice(others)] if others else i
            ret.append(select_ind)

        return iter(ret)
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from __future__ import absolute_import | ||
import os.path as osp | ||
import torch | ||
from PIL import Image | ||
|
||
|
||
|
||
class SeqTrainPreprocessor(object):
    """Load one training sequence sample as (image, flow) tensor stacks.

    Each entry of `seqset` is a (start, end, pid, label, camid) tuple;
    RGB frames are read from `dataset.images_dir` and optical-flow
    images from `dataset.other_dir`, keyed by the same file name.
    Returns (img_tensor, flow_tensor, label, camid) per item.
    """

    def __init__(self, seqset, dataset, seq_len, transform=None):
        super(SeqTrainPreprocessor, self).__init__()
        self.seqset = seqset
        self.identities = dataset.identities
        self.transform = transform
        self.seq_len = seq_len
        # root[0]: RGB frames, root[1]: optical-flow frames.
        self.root = [dataset.images_dir]
        self.root.append(dataset.other_dir)

    def __len__(self):
        return len(self.seqset)

    def __getitem__(self, indices):
        # A batch of indices yields a list of samples; a single index
        # yields one sample.
        if isinstance(indices, (tuple, list)):
            return [self._get_single_item(one) for one in indices]
        return self._get_single_item(indices)

    def _get_single_item(self, index):
        start_ind, end_ind, pid, label, camid = self.seqset[index]

        images, flows = [], []
        for frame in range(start_ind, end_ind):
            fname = self.identities[pid][camid][frame]
            rgb = Image.open(osp.join(self.root[0], fname)).convert('RGB')
            flow = Image.open(osp.join(self.root[1], fname)).convert('RGB')
            images.append(rgb)
            flows.append(flow)

        # Pad short tracks by repeating the last frame up to seq_len.
        # (Assumes end_ind > start_ind, which _pluckseq guarantees.)
        while len(images) < self.seq_len:
            images.append(images[-1])
            flows.append(flows[-1])

        seq = [images, flows]
        if self.transform is not None:
            # The transform operates jointly on [image list, flow list]
            # so spatial augmentations stay aligned across modalities.
            seq = self.transform(seq)

        img_tensor = torch.stack(seq[0], 0)
        flow_tensor = torch.stack(seq[1], 0)
        return img_tensor, flow_tensor, label, camid
|
||
|
||
|
||
class SeqTestPreprocessor(object):
    """Load one query/gallery sequence as (image, flow) tensor stacks.

    Mirrors SeqTrainPreprocessor but returns the raw person id (pid)
    instead of the training label:
    (img_tensor, flow_tensor, pid, camid) per item.
    """

    def __init__(self, seqset, dataset, seq_len, transform=None):
        super(SeqTestPreprocessor, self).__init__()
        self.seqset = seqset
        self.identities = dataset.identities
        self.transform = transform
        self.seq_len = seq_len
        # root[0]: RGB frames, root[1]: optical-flow frames.
        self.root = [dataset.images_dir]
        self.root.append(dataset.other_dir)

    def __len__(self):
        return len(self.seqset)

    def __getitem__(self, indices):
        # A batch of indices yields a list of samples; a single index
        # yields one sample.
        if isinstance(indices, (tuple, list)):
            return [self._get_single_item(one) for one in indices]
        return self._get_single_item(indices)

    def _get_single_item(self, index):
        start_ind, end_ind, pid, label, camid = self.seqset[index]

        images, flows = [], []
        for frame in range(start_ind, end_ind):
            fname = self.identities[pid][camid][frame]
            rgb = Image.open(osp.join(self.root[0], fname)).convert('RGB')
            flow = Image.open(osp.join(self.root[1], fname)).convert('RGB')
            images.append(rgb)
            flows.append(flow)

        # Pad short tracks by repeating the last frame up to seq_len.
        # (Assumes end_ind > start_ind, which _pluckseq guarantees.)
        while len(images) < self.seq_len:
            images.append(images[-1])
            flows.append(flows[-1])

        seq = [images, flows]
        if self.transform is not None:
            # Joint transform keeps image and flow spatially aligned.
            seq = self.transform(seq)

        img_tensor = torch.stack(seq[0], 0)

        # Flow is stacked only when a second root (other_dir) is
        # configured; with the current __init__ that is always true.
        if len(self.root) == 2:
            flow_tensor = torch.stack(seq[1], 0)
        else:
            flow_tensor = None

        return img_tensor, flow_tensor, pid, camid
Binary file not shown.
Oops, something went wrong.