Nr. of devices needed #38

Open
zcobol opened this issue Mar 18, 2024 · 36 comments

@zcobol

zcobol commented Mar 18, 2024

Running python run.py on a single NVIDIA GPU fails with ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)

Can the number of devices be adjusted to 1?

@nickorlabs

INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 8) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 2 devices in mesh
Traceback (most recent call last):
  File "/opt/grok-1/run.py", line 72, in <module>
    main()
  File "/opt/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/opt/grok-1/runners.py", line 282, in initialize
    runner.initialize(
  File "/opt/grok-1/runners.py", line 181, in initialize
    self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
  File "/opt/grok-1/runners.py", line 586, in make_mesh
    device_mesh = mesh_utils.create_hybrid_device_mesh(
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/site-packages/jax/experimental/mesh_utils.py", line 373, in create_hybrid_device_mesh
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/site-packages/jax/experimental/mesh_utils.py", line 373, in <listcomp>
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/site-packages/jax/experimental/mesh_utils.py", line 302, in create_device_mesh
    raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 2 must equal the product of mesh_shape (1, 8)

Is this what you get?

@yarodevuci

I put 1 instead of 8.

@yarodevuci

I keep getting the same error: PermissionError: [WinError 32] The process cannot access the file because it is being used by another process: 'D:\dev\shm\tmpp53ohpcl'

@KHARAPSY

I have the same issue. Is there a way to resolve this?

@zRzRzRzRzRzRzR

zRzRzRzRzRzRzR commented Mar 18, 2024

Same issue even with all requirements installed. I am using 8 GPUs.

@nickorlabs

I have 2 GPUs and everything installed ok as well.

@bluevisor

In run.py, I changed line 60:
local_mesh_config=(1, 8),
to
local_mesh_config=(1, 1),

(I have one 3090.)
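The check behind this error runs once per host: the traceback shows `mesh_utils.create_hybrid_device_mesh` building one mesh per granule, and each must contain exactly product(local_mesh_config) devices. Changing (1, 8) to (1, 1) satisfies that rule for a single GPU. A minimal plain-Python sketch, assuming a single host and keeping run.py's convention of a 'data' axis of size 1; `check_mesh` and `pick_local_mesh_config` are hypothetical helpers, not part of the grok-1 code.

```python
from math import prod


def check_mesh(local_mesh_config: tuple, num_local_devices: int) -> None:
    """Raise the same kind of error JAX reports when the per-host mesh
    does not cover exactly the devices found on that host."""
    if prod(local_mesh_config) != num_local_devices:
        raise ValueError(
            f"Number of devices {num_local_devices} must equal the product "
            f"of mesh_shape {local_mesh_config}"
        )


def pick_local_mesh_config(num_local_devices: int) -> tuple:
    """Hypothetical helper: put every local device on the 'model' axis and
    keep the 'data' axis at 1, as run.py does with (1, 8)."""
    config = (1, num_local_devices)
    check_mesh(config, num_local_devices)
    return config


if __name__ == "__main__":
    # On a real machine the count would come from jax.local_device_count().
    for n in (1, 2, 8):
        print(n, pick_local_mesh_config(n))
    try:
        check_mesh((1, 8), 2)
    except ValueError as err:
        print(err)  # Number of devices 2 must equal the product of mesh_shape (1, 8)
```

Passing this check only gets past mesh construction; the weights still have to divide evenly across the chosen 'model' axis and fit in memory.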

@nickorlabs

nickorlabs commented Mar 18, 2024

Ok got a little further this time!

Traceback (most recent call last):
  File "/opt/grok-1/run.py", line 72, in <module>
    main()
  File "/opt/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/opt/grok-1/runners.py", line 294, in initialize
    params = runner.load_or_init(dummy_data)
  File "/opt/grok-1/runners.py", line 238, in load_or_init
    state = xai_checkpoint.restore(
  File "/opt/grok-1/checkpoint.py", line 196, in restore
    loaded_tensors = load_tensors(ckpt_shapes_flat, ckpt_path, between_hosts_config)
  File "/opt/grok-1/checkpoint.py", line 107, in load_tensors
    return [f.result() for f in fs]
  File "/opt/grok-1/checkpoint.py", line 107, in <listcomp>
    return [f.result() for f in fs]
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/concurrent/futures/_base.py", line 449, in result
    return self.__get_result()
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/opt/grok-1/checkpoint.py", line 72, in fast_unpickle
    with copy_to_shm(path) as tmp_path:
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/contextlib.py", line 137, in __enter__
    return next(self.gen)
  File "/opt/grok-1/checkpoint.py", line 52, in copy_to_shm
    shutil.copyfile(file, tmp_path)
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/shutil.py", line 269, in copyfile
    _fastcopy_sendfile(fsrc, fdst)
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/shutil.py", line 158, in _fastcopy_sendfile
    raise err from None
  File "/opt/anaconda3/envs/groq-1/lib/python3.11/shutil.py", line 144, in _fastcopy_sendfile
    sent = os.sendfile(outfd, infd, offset, blocksize)
OSError: [Errno 28] No space left on device: './checkpoints/ckpt-0/tensor00000_000' -> '/dev/shm/tmpi8_qagu5'

I have 2 Quadro 5000s; I guess we do not have enough VRAM, doh.

@bluevisor

I'm at the same point. GPT told me /dev/shm is a RAM disk, which means we don't have enough RAM, not VRAM.

I have 64 GB; not sure how much we need... would 128 GB be enough?
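On Linux, /dev/shm is normally a tmpfs mount backed by RAM (and swap), and a default tmpfs is sized at half of physical RAM, so a machine can run out of /dev/shm space well before it runs out of RAM. A minimal sketch for checking headroom before launching, assuming the checkpoint lives at ./checkpoints/ckpt-0 as in the error message; `dir_size_bytes` and `shm_headroom` are hypothetical helpers. Comparing against the whole checkpoint is a conservative test: how much of it sits in /dev/shm at one moment depends on how the loader schedules its copies, and the loaded tensors need RAM on top of that.

```python
import os
import shutil


def dir_size_bytes(path: str) -> int:
    """Sum the sizes of all regular files under `path` (0 if it is missing)."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            full = os.path.join(root, name)
            if os.path.isfile(full):
                total += os.path.getsize(full)
    return total


def shm_headroom(ckpt_dir: str = "./checkpoints/ckpt-0",
                 shm_dir: str = "/dev/shm") -> dict:
    """Compare the checkpoint size with the free space in `shm_dir`."""
    needed = dir_size_bytes(ckpt_dir)
    free = shutil.disk_usage(shm_dir).free
    return {
        "checkpoint_gib": needed / 2**30,
        "shm_free_gib": free / 2**30,
        "fits": free >= needed,
    }


if __name__ == "__main__":
    if os.path.isdir("/dev/shm"):  # /dev/shm only exists on Linux-like systems
        print(shm_headroom())
```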

@nickorlabs

I have 128 GB of RAM on this rig, and the two cards have about 32 GB combined, which is why I assumed VRAM. I could be wrong, though.

@bluevisor

Bummer... guess we'll just have to wait for GGUF...

@nickorlabs

Possibly. I might spin up a RunPod instance, or wait for GGUF; I was reading that people need 8 GPUs.

@thisIsLoading

After changing the mesh to (1, 6), I get this error:

INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 6) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 6 devices in mesh
2024-03-18 15:58:10.001688: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
INFO:rank:partition rules: <bound method LanguageModelConfig.partition_rules of LanguageModelConfig(model=TransformerConfig(emb_size=6144, key_size=128, num_q_heads=48, num_kv_heads=8, num_layers=64, vocab_size=131072, widening_factor=8, attn_output_multiplier=0.08838834764831845, name=None, num_experts=8, capacity_factor=1.0, num_selected_experts=2, init_scale=1.0, shard_activations=True, data_axis='data', model_axis='model'), vocab_size=131072, pad_token=0, eos_token=2, sequence_len=8192, model_size=6144, embedding_init_scale=1.0, embedding_multiplier_scale=78.38367176906169, output_multiplier_scale=0.5773502691896257, name=None, fprop_dtype=<class 'jax.numpy.bfloat16'>, model_type=None, init_scale_override=None, shard_embeddings=True)>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
INFO:rank:State sharding type: <class 'model.TrainingState'>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/loading/PycharmProjects/grok-1/run.py", line 72, in <module>
    main()
  File "/home/loading/PycharmProjects/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/home/loading/PycharmProjects/grok-1/runners.py", line 294, in initialize
    params = runner.load_or_init(dummy_data)
  File "/home/loading/PycharmProjects/grok-1/runners.py", line 235, in load_or_init
    state_shapes = jax.eval_shape(self.init_fn, rng, init_data)
ValueError: One of pjit outputs with pytree key path .params['transformer/decoder_layer_0/moe/linear']['w'] was given the sharding of NamedSharding(mesh=Mesh('data': 1, 'model': 6), spec=PartitionSpec(None, 'data', 'model')), which implies that the global size of its dimension 2 should be divisible by 6, but it is equal to 32768 (full shape: (8, 6144, 32768))

Looks like it doesn't like 6 either.
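The failure is about divisibility rather than device count. With PartitionSpec(None, 'data', 'model') on a mesh of {'data': 1, 'model': 6}, axis 2 of the (8, 6144, 32768) MoE weight would have to split into six equal pieces, and 32768 = 2^15 has no factor of 3. The plain-Python sketch that follows imitates that rule; `shard_shape` is a hypothetical helper, not JAX's implementation, and it checks only this one tensor, so other weights may add further constraints.

```python
from math import prod
from typing import Sequence, Union

# One entry per array dimension: None (not split), a mesh axis name,
# or a tuple of mesh axis names.
AxisSpec = Union[None, str, tuple]


def shard_shape(global_shape: Sequence[int], spec: Sequence[AxisSpec],
                mesh: dict) -> tuple:
    """Per-device shape of an array split by `spec` over a named mesh.

    Every split dimension must be evenly divisible by the product of the
    mesh axis sizes it is split over, which is the rule JAX enforces.
    """
    out = []
    for dim, (size, axes) in enumerate(zip(global_shape, spec)):
        if axes is None:
            out.append(size)
            continue
        names = (axes,) if isinstance(axes, str) else tuple(axes)
        ways = prod(mesh[name] for name in names)
        if size % ways:
            raise ValueError(
                f"dimension {dim} of size {size} is not divisible by {ways}"
            )
        out.append(size // ways)
    out.extend(global_shape[len(spec):])  # dims beyond the spec stay whole
    return tuple(out)


if __name__ == "__main__":
    moe_w = (8, 6144, 32768)        # shape from the error message
    spec = (None, "data", "model")  # PartitionSpec(None, 'data', 'model')
    for model_axis in (1, 2, 4, 6, 8):
        try:
            print(model_axis, shard_shape(moe_w, spec, {"data": 1, "model": model_axis}))
        except ValueError as err:
            print(model_axis, "->", err)
```

Model-axis sizes of 1, 2, 4 and 8 divide this tensor while 6 raises. The checkpoint's tensor shapes are fixed by the files on disk, so config edits cannot make 32768 divisible by 6.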

@thisIsLoading

Looks like I have to set

        widening_factor=6,
        num_kv_heads=6,

in the TransformerConfig to the number of devices as well

@nickorlabs

Did you get it up and running?

@yarodevuci

Looks like I have to set

        widening_factor=6,
        num_kv_heads=6,

in the TransformerConfig to the number of devices as well

Did it work after that?

@thisIsLoading

@yarodevuci Still downloading the weights.

I was under the impression that the test would download them (looks like I'm spoiled by the Hugging Face API, which does that). Will report back tomorrow. Right now it tells me 17 more hours (I don't know why it takes so long; I'm on 750 Mbit, but the magnet download is painfully slow).

@nickorlabs

I'm seeding (again). It took me most of the evening last night to download, and I have a 2000 Mbit connection.

@ad1tyac0des

I'm at the same point. GPT told me /dev/shm is a RAM disk, which means we don't have enough RAM, not VRAM.

I have 64 GB; not sure how much we need... would 128 GB be enough?

My system has 192 GB of RAM, and I encountered the same error:
OSError: [Errno 28] No space left on device: './checkpoints/ckpt-0/tensor00000_000' -> '/dev/shm/tmpbeofn6hn'

@yarodevuci

@ad1tyac0des It creates a temp folder with over 300 GB in it; do you have that much space on the hard drive?

@toughcoding

Did anybody here see a live presentation where the X developers ran it using exact commands, or are we all testing it for them?

@pwxpwxtop

What a rip-off: downloading it cost me a whole day of effort.

@KHARAPSY

I succeeded in increasing the space and got rid of this error:
"OSError: [Errno 28] No space left on device: './checkpoints/ckpt-0/tensor00000_000' -> '/dev/shm/tmpi8_qagu5'"

But in exchange, I ended up with a system crash instead, so I will give up for now. I don't have enough RAM to run Grok-1, nor enough money to upgrade my hardware.

@zRzRzRzRzRzRzR

zRzRzRzRzRzRzR commented Mar 19, 2024

Same issue even with all requirements installed. I am using 8 GPUs.

I switched to 8 x A100 GPUs, and it used 65 GB of memory per GPU to run this model. The resources required to run this model are quite large. The requirements installed successfully.

Finally, I ran it with this command:

JAX_TRACEBACK_FILTERING=off python run.py

and it works.

@ad1tyac0des

@ad1tyac0des It creates a temp folder with over 300 GB in it; do you have that much space on the hard drive?

I had about 100 GB of storage left, but at the moment the error occurred, my system's RAM was completely used up. That seems to be why the program stopped: the problem was high RAM usage rather than storage space.

@thisIsLoading

Am at 272/300 GB right now. Excitement is starting to kick in; let's hope this thing runs.

Only having 6x 4090 (144 GB VRAM) and 512 GB RAM: if this isn't enough to at least run it, regardless of speed, then something is off.
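A rough back-of-the-envelope comparison using only figures reported in this thread: about 65 GB in use on each of 8 A100s in the run that worked, versus 6 RTX 4090s at 24 GB each. It treats that single report as representative and ignores quantization or CPU offloading; `shortfall_gb` is a hypothetical helper.

```python
def shortfall_gb(num_gpus: int, gb_per_gpu: float, needed_gb: float) -> float:
    """GB of GPU memory missing on one host, or 0.0 if it fits."""
    return max(0.0, float(needed_gb - num_gpus * gb_per_gpu))


# ~65 GB in use on each of 8 A100s, as reported in this thread (one data point).
observed_need_gb = 8 * 65

print(observed_need_gb)                       # 520
print(shortfall_gb(6, 24, observed_need_gb))  # 376.0  (6x RTX 4090)
print(shortfall_gb(8, 80, observed_need_gb))  # 0.0    (8x A100 80GB)
```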

@thisIsLoading

OK, got a little further but still no cigar:

(.venv) loading@ai:~/PycharmProjects/grok-1$ python run.py
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 6) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 6 devices in mesh
2024-03-19 07:55:00.536833: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
INFO:rank:partition rules: <bound method LanguageModelConfig.partition_rules of LanguageModelConfig(model=TransformerConfig(emb_size=6144, key_size=128, num_q_heads=48, num_kv_heads=6, num_layers=64, vocab_size=131072, widening_factor=6, attn_output_multiplier=0.08838834764831845, name=None, num_experts=8, capacity_factor=1.0, num_selected_experts=2, init_scale=1.0, shard_activations=True, data_axis='data', model_axis='model'), vocab_size=131072, pad_token=0, eos_token=2, sequence_len=8192, model_size=6144, embedding_init_scale=1.0, embedding_multiplier_scale=78.38367176906169, output_multiplier_scale=0.5773502691896257, name=None, fprop_dtype=<class 'jax.numpy.bfloat16'>, model_type=None, init_scale_override=None, shard_embeddings=True)>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
INFO:rank:State sharding type: <class 'model.TrainingState'>
INFO:rank:(1, 256, 6144)
INFO:rank:(1, 256, 131072)
INFO:rank:Loading checkpoint at ./checkpoints/ckpt-0
Traceback (most recent call last):
  File "/home/loading/PycharmProjects/grok-1/run.py", line 72, in <module>
    main()
  File "/home/loading/PycharmProjects/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/home/loading/PycharmProjects/grok-1/runners.py", line 294, in initialize
    params = runner.load_or_init(dummy_data)
  File "/home/loading/PycharmProjects/grok-1/runners.py", line 238, in load_or_init
    state = xai_checkpoint.restore(
  File "/home/loading/PycharmProjects/grok-1/checkpoint.py", line 218, in restore
    state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/experimental/multihost_utils.py", line 342, in host_local_array_to_global_array
    out_flat = [
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/experimental/multihost_utils.py", line 343, in <listcomp>
    host_local_array_to_global_array_p.bind(inp, global_mesh=global_mesh,
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 420, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 423, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/experimental/multihost_utils.py", line 250, in host_local_array_to_global_array_impl
    for d, index in local_sharding.devices_indices_map(arr.shape).items()]
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 110, in devices_indices_map
    return common_devices_indices_map(self, global_shape)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 59, in common_devices_indices_map
    return gspmd_sharding.devices_indices_map(global_shape)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 898, in devices_indices_map
    return gspmd_sharding_devices_indices_map(self, global_shape)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 826, in gspmd_sharding_devices_indices_map
    self.shard_shape(global_shape)  # raises a good error message
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 122, in shard_shape
    return _common_shard_shape(self, global_shape)
  File "/home/loading/PycharmProjects/grok-1/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 77, in _common_shard_shape
    raise ValueError(
ValueError: Sharding GSPMDSharding({devices=[1,1,6]<=[6]}) implies that array axis 2 is partitioned 6 times, but the dimension size is 32768 (full shape: (8, 6144, 32768), per-dimension tiling factors: [1, 1, 6] should evenly divide the shape)
(.venv) loading@ai:~/PycharmProjects/grok-1$

Every 2.0s: nvidia-smi                                    ai: Tue Mar 19 07:56:37 2024

Tue Mar 19 07:56:37 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:16:00.0 Off |                  Off |
|  0%   31C    P8              23W / 450W |      3MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        On  | 00000000:34:00.0 Off |                  Off |
|  0%   30C    P8              25W / 450W |      3MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA GeForce RTX 4090        On  | 00000000:52:00.0 Off |                  Off |
|  0%   30C    P8              25W / 450W |      3MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA GeForce RTX 4090        On  | 00000000:70:00.0 Off |                  Off |
|  0%   30C    P8              20W / 450W |      3MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA GeForce RTX 4090        On  | 00000000:AC:00.0 Off |                  Off |
|  0%   32C    P8              29W / 450W |      3MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA GeForce RTX 4090        On  | 00000000:CA:00.0 Off |                  Off |
|  0%   30C    P8              16W / 450W |      3MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

@malinichev

malinichev commented Mar 19, 2024

  • In checkpoint.py, I changed /dev/shm/ to './dev/shm/'
  • In the terminal: mkdir -p ./dev/shm/
  • After that, I ran python run.py

But my MacBook M1 Pro (16 GB RAM / 512 GB SSD) freezes after that, because Python eats up all my memory. I assume there is not enough memory even when using the SSD.
@xSetech xSetech mentioned this issue Mar 19, 2024
@surak

surak commented Mar 19, 2024

Am at 272/300 GB right now. Excitement is starting to kick in; let's hope this thing runs.
Only having 6x 4090 (144 GB VRAM) and 512 GB RAM: if this isn't enough to at least run it, regardless of speed, then something is off.

It probably isn't enough. I have 4 A100s and 512 GB of RAM per node as well, and I am not sure I can run it. It has been stuck loading checkpoints for a while now.

@Christmas-Wong

You should install jaxlib for CUDA so that your 8 GPUs can be detected. Alternatively, you can set local_mesh_config=(1, 1), and Grok will run on the CPU.

@SamKnightV

  • In checkpoint.py, I changed /dev/shm/ to './dev/shm/'
  • In the terminal: mkdir -p ./dev/shm/
  • After that, I ran python run.py

But my MacBook M1 Pro (16 GB RAM / 512 GB SSD) freezes after that, because Python eats up all my memory. I assume there is not enough memory even when using the SSD.

After doing that, I got this error: zsh: killed python run.py

@surak

surak commented Mar 21, 2024

  • In checkpoint.py, I changed /dev/shm/ to './dev/shm/'
  • In the terminal: mkdir -p ./dev/shm/
  • After that, I ran python run.py

But my MacBook M1 Pro (16 GB RAM / 512 GB SSD) freezes after that, because Python eats up all my memory. I assume there is not enough memory even when using the SSD.
After doing that, I got this error: zsh: killed python run.py

We are talking about machines with 512 GB of RAM and hundreds of GB of VRAM not being able to run it, not a laptop. You will have to wait for a WAY smaller version of it to run on a small machine.

@SamKnightV

I have an iMac with a 3.6 GHz 10-core Intel Core i9, AMD Radeon Pro 5300 4 GB graphics, and 16 GB of 2667 MHz DDR4 memory ))) What do I need to change?

@surak

surak commented Mar 21, 2024

))))

You need a real data-center GPU compute node with at least 8 x A100 80 GB to run Grok at this point. I doubt that any quantized version would fit on a Mac anytime soon, but who knows? )))

@Na-Yun1990

raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)
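Oh well ~.~ Old Ma [Musk] really shows no mercy. Ordinary people should just stick to playing with Grok or GPT or the like. If your hardware isn't up to spec, this thing simply will not let you use it. When I strike it rich, I'll build a server out of 24 RTX 4090s and have some fun!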

@Na-Yun1990

I have an iMac with a 3.6 GHz 10-core Intel Core i9, AMD Radeon Pro 5300 4 GB graphics, and 16 GB of 2667 MHz DDR4 memory ))) What do I need to change?

You need to change everything. Ordinary consumer hardware cannot run this. Maybe an Amazon cloud server can run Grok-1, but the price will definitely be high.
