Skip to content

Commit

Permalink
Merge pull request #290 from Anmol6/fix-rank-1-case
Browse files Browse the repository at this point in the history
Indexing: handle cpu & single-gpu without using multiprocessing & dist. data parallel
  • Loading branch information
okhat committed Jan 14, 2024
2 parents 03fb1be + b9b6004 commit 9389495
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 16 deletions.
22 changes: 15 additions & 7 deletions colbert/indexing/collection_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,23 @@ def _sample_embeddings(self, sampled_pids):
local_sample_embs, doclens = self.encoder.encode_passages(local_sample)

if torch.cuda.is_available():
self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).cuda()
torch.distributed.all_reduce(self.num_sample_embs)
if torch.distributed.is_available() and torch.distributed.is_initialized():
self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).cuda()
torch.distributed.all_reduce(self.num_sample_embs)

avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
avg_doclen_est = torch.tensor([avg_doclen_est]).cuda()
torch.distributed.all_reduce(avg_doclen_est)

avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
avg_doclen_est = torch.tensor([avg_doclen_est]).cuda()
torch.distributed.all_reduce(avg_doclen_est)
nonzero_ranks = torch.tensor([float(len(local_sample) > 0)]).cuda()
torch.distributed.all_reduce(nonzero_ranks)
else:
self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).cuda()

avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
avg_doclen_est = torch.tensor([avg_doclen_est]).cuda()

nonzero_ranks = torch.tensor([float(len(local_sample) > 0)]).cuda()
torch.distributed.all_reduce(nonzero_ranks)
nonzero_ranks = torch.tensor([float(len(local_sample) > 0)]).cuda()
else:
if torch.distributed.is_available() and torch.distributed.is_initialized():
self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).cpu()
Expand Down
2 changes: 2 additions & 0 deletions colbert/infra/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class RunSettings:
total_visible_gpus = torch.cuda.device_count()
gpus: int = DefaultVal(total_visible_gpus)

avoid_fork_if_possible: bool = DefaultVal(False)

@property
def gpus_(self):
value = self.gpus
Expand Down
33 changes: 24 additions & 9 deletions colbert/infra/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,19 @@ def __init__(self, callee, run_config=None, return_all=False):
self.nranks = self.run_config.nranks

def launch(self, custom_config, *args):
assert isinstance(custom_config, BaseConfig)
assert isinstance(custom_config, RunSettings)

if self.nranks == 1 and self.run_config.avoid_fork_if_possible:
new_config = type(custom_config).from_existing(custom_config, self.run_config, RunConfig(rank=0))
return_val = run_process_without_mp(self.callee, new_config, *args)
return return_val

return_value_queue = mp.Queue()

rng = random.Random(time.time())
port = str(12355 + rng.randint(0, 1000)) # randomize the port to avoid collision on launching several jobs.

all_procs = []
for new_rank in range(0, self.nranks):
assert isinstance(custom_config, BaseConfig)
assert isinstance(custom_config, RunSettings)

new_config = type(custom_config).from_existing(custom_config, self.run_config, RunConfig(rank=new_rank))

args_ = (self.callee, port, return_value_queue, new_config, *args)
Expand Down Expand Up @@ -88,13 +91,25 @@ def launch(self, custom_config, *args):
return return_values


def set_seed(seed):
    """Seed every random number generator this module relies on.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator, and the
    generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )

    # Order matches the original statement sequence.
    for apply_seed in seeders:
        apply_seed(seed)

def run_process_without_mp(callee, config, *args):
    """Invoke ``callee`` directly in the current process.

    Seeds the random number generators with 12345, restricts
    ``CUDA_VISIBLE_DEVICES`` to the first ``config.nranks`` entries of
    ``config.gpus_``, and then calls ``callee(config, *args)`` inside
    ``Run().context(config, inherit_config=False)``.

    Returns:
        Whatever ``callee`` returns. The CUDA cache is emptied after a
        successful call, before returning.
    """
    set_seed(12345)

    visible_gpus = config.gpus_[:config.nranks]
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in visible_gpus)

    with Run().context(config, inherit_config=False):
        result = callee(config, *args)
        # Release cached GPU memory held by the allocator once the callee is done.
        torch.cuda.empty_cache()
        return result

def setup_new_process(callee, port, return_value_queue, config, *args):
print_memory_stats()

random.seed(12345)
np.random.seed(12345)
torch.manual_seed(12345)
torch.cuda.manual_seed_all(12345)
set_seed(12345)

rank, nranks = config.rank, config.nranks

Expand Down

0 comments on commit 9389495

Please sign in to comment.