[Bug] the number of embeddings in ManagedCollisionCollection must be a multiple of the number of devices #1591

Open
fangleigit opened this issue Dec 18, 2023 · 3 comments

Comments

@fangleigit

When changing num_embeddings to 4091 and mch_size to 1021 in the code below, it throws the following exception:

ValueError: ShardedTensor global_size property does not match from different ranks! Found global_size=torch.Size([3070]) on rank:0, and global_size=torch.Size([3068]) on rank:1.
ValueError: ShardedTensor global_size property does not match from different ranks! Found global_size=torch.Size([3070]) on rank:0, and global_size=torch.Size([3068]) on rank:1.
Traceback (most recent call last):
  File "test2.py", line 143, in <module>
    spmd_sharing_simulation(ShardingType.ROW_WISE)
  File "test2.py", line 139, in spmd_sharing_simulation
    assert 0 == p.exitcode
AssertionError

Code to reproduce:
import os
from typing import Dict, cast

import multiprocess
import torch
import torch.distributed as dist
import torchrec
from torchrec.distributed.mc_embeddingbag import ManagedCollisionEmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    ManagedCollisionModule,
    MCHManagedCollisionModule,
)


def preprocess_func(id: torch.Tensor, hash_size: int) -> torch.Tensor:
    return id % hash_size


os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

table_name = "sample"

tables = [
    torchrec.EmbeddingBagConfig(
        name=table_name,
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=[table_name],
        pooling=torchrec.PoolingType.SUM,
    )
]

mcc = ManagedCollisionCollection(
    managed_collision_modules={table_name: cast(
        ManagedCollisionModule,
        MCHManagedCollisionModule(
            zch_size=3070,  # zch_size + mch_size == num_embeddings (3070 + 1026 == 4096)
            mch_size=1026,
            device="meta",
            eviction_interval=1,
            eviction_policy=DistanceLFU_EvictionPolicy(),
            mch_hash_func=preprocess_func,
        ),
    )},
    embedding_configs=tables,
)

ebc: ManagedCollisionEmbeddingBagCollection = ManagedCollisionEmbeddingBagCollection(
    EmbeddingBagCollection(
        tables=tables,
        device='meta',
    ),
    mcc,
    return_remapped_features=False,
)


def single_rank_execution(
    rank: int,
    world_size: int,
    constraints: Dict[str, ParameterConstraints],
    module: torch.nn.Module,
    backend: str,
) -> DistributedModelParallel:

    def init_distributed_single_host(
        rank: int,
        world_size: int,
        backend: str,
        # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
    ) -> dist.ProcessGroup:
        os.environ["RANK"] = f"{rank}"
        os.environ["WORLD_SIZE"] = f"{world_size}"
        dist.init_process_group(
            rank=rank, world_size=world_size, backend=backend)
        return dist.group.WORLD

    if backend == "nccl":
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    topology = Topology(world_size=world_size, compute_device="cuda")
    pg = init_distributed_single_host(rank, world_size, backend)
    planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )
    sharders = [cast(ModuleSharder[torch.nn.Module],
                     ManagedCollisionEmbeddingBagCollectionSharder())]
    plan = planner.collective_plan(module, sharders=sharders, pg=pg)

    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=device,
    )
    print(f"rank:{rank},sharding plan: {plan}")
    return sharded_model


def spmd_sharing_simulation(
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
    world_size=2,
):
    ctx = multiprocess.get_context("spawn")
    processes = []
    for rank in range(world_size):
        p = ctx.Process(
            target=single_rank_execution,
            args=(
                rank,
                world_size,
                {
                    table_name: ParameterConstraints(
                        sharding_types=[sharding_type.value],
                    )
                },
                ebc,
                "nccl"
            ),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        assert 0 == p.exitcode


if __name__ == '__main__':
    spmd_sharing_simulation(ShardingType.ROW_WISE)
@henrylhtsang
Contributor

Hi, thanks for trying out ManagedCollisionCollection!

Not sure if it's a bug. The thing is, we intend ManagedCollisionCollection to be used only with row-wise sharding, which shards the table evenly across all the GPUs, hence the divisibility requirement.
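
If divisibility really is the only constraint, one workaround is to round the sizes so that zch_size, mch_size, and therefore the table's num_embeddings (zch_size + mch_size in the repro above) are all multiples of world_size. A rough sketch, not an official recommendation; round_down_to_multiple is just a made-up helper:

def round_down_to_multiple(value: int, multiple: int) -> int:
    # Largest value' <= value with value' % multiple == 0.
    return (value // multiple) * multiple

world_size = 2
zch_size = round_down_to_multiple(3070, world_size)  # 3070 is already a multiple of 2
mch_size = round_down_to_multiple(1021, world_size)  # 1021 -> 1020
num_embeddings = zch_size + mch_size                 # 4090, splits evenly across 2 ranks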

@fangleigit
Author

fangleigit commented Dec 19, 2023

Hi, thanks for trying out ManagedCollisionCollection!

Not sure if it's a bug. The thing is, we intend ManagedCollisionCollection to be used only with row-wise sharding, which shards the table evenly across all the GPUs, hence the divisibility requirement.

Thanks for your quick response. Yes, I tried ManagedCollisionCollection on our data, and the performance degraded; the training time also increased significantly. Is there any guideline or documentation on how to set the hyper-parameters for this module, e.g., eviction_interval, zch_size, mch_size, and on which policy is better in which scenario, DistanceLFU_EvictionPolicy or LFU_EvictionPolicy?
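
For reference while experimenting, switching between the two policies in the repro above is a one-argument change. The sketch below assumes LFU_EvictionPolicy can be constructed with no arguments, like DistanceLFU_EvictionPolicy is in the repro; my understanding is that LFU ranks ids purely by access frequency, while DistanceLFU also weighs how recently an id was seen:

import torch
from torchrec.modules.mc_modules import LFU_EvictionPolicy, MCHManagedCollisionModule


def preprocess_func(id: torch.Tensor, hash_size: int) -> torch.Tensor:
    # Same hash preprocessing as in the repro above.
    return id % hash_size


# Same module as in the repro, with only eviction_policy swapped.
mc_module_lfu = MCHManagedCollisionModule(
    zch_size=3070,
    mch_size=1026,
    device="meta",
    eviction_interval=1,                   # evict once per batch, as in the repro
    eviction_policy=LFU_EvictionPolicy(),  # instead of DistanceLFU_EvictionPolicy()
    mch_hash_func=preprocess_func,
)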

@henrylhtsang
Contributor

@fangleigit Thanks. We are still actively developing MCH/ZCH, so we don't have a clear answer yet. Let us know if you figure it out as well!
