In [1]:
import accelerate
from accelerate import init_empty_weights
from accelerate import load_checkpoint_and_dispatch, infer_auto_device_map
from accelerate.utils import get_balanced_memory


In [2]:
import torch.nn as nn
import torch
import json


In [3]:
input = torch.randn(1, 20)
def get_model():
    return nn.Sequential(
          nn.Linear(20, 20),
          nn.Linear(20, 20),
        )
real_model = get_model()
output = real_model.forward(input)
state_dict = real_model.state_dict()
state_dict

new_dict = {}
for name, tensor in state_dict.items():
    fname = f"model_{name}.pt"
    torch.save(state_dict, fname)
    new_dict[name] = fname

checkpoint_path='index.json'
with open(checkpoint_path, 'w') as f:
    json.dump(new_dict, f, indent=2)
    


In [13]:
import ray
# @ray.remote(num_gpus=1)
class RayModelShard:
    def __init__(self, device_id: int, model, device_map, checkpoint_path):
        self.actor_name = f"shard_{device_id}"
        print(locals())
        import os
#         local_device = os.getenv('CUDA_VISIBLE_DEVICES')
        if os.getenv('CUDA_VISIBLE_DEVICES'):
            del os.environ['CUDA_VISIBLE_DEVICES']
        print('Ray gpu ids', ray.get_gpu_ids())
#         os.environ['CUDA_LAUNCH_BLOCKING']='1'
        self.device_id = device_id
        new_device_map = {
            shard_name: device_id if device == device_id else 'cpu' for shard_name, device in device_map.items()
        }
        print('device_ids', device_id, new_device_map)
        self.device_map = new_device_map
        self.model = model
        print('trying to load the model...')

        self.loaded_model = load_checkpoint_and_dispatch(model, checkpoint_path, device_map=new_device_map)
        print("ran init")

    
    def forward(self, *args, **kwargs):
#         return "ran_forward"
        print('run forward', self.actor_name)
        return self.loaded_model.forward(*args, **kwargs)

RayModelShardActor = ray.remote(num_gpus=1)(RayModelShard)


class Dispatcher:
    def __init__(self, device_map):
        self.device_map = device_map
        
    def dispatch(self, *args, _module**kwargs,)

In [14]:
with init_empty_weights():
    model = get_model()

print(model)
# Get the device map
max_memory = get_balanced_memory(
    model,
)
device_map = infer_auto_device_map(model, max_memory=max_memory)

devices = sorted(set(device_map.values()))

Sequential(
  (0): Linear(in_features=20, out_features=20, bias=True)
  (1): Linear(in_features=20, out_features=20, bias=True)
)


In [15]:
# device_actors = [RayModelShard(device, model, device_map, checkpoint_path="/home/ray/tc-test-bloom/tc-test/index.json") for device in devices]
device_actors = [RayModelShardActor.remote(device, model, device_map, checkpoint_path="/home/ray/tc-test-bloom/tc-test/index.json") for device in devices]


[2m[36m(RayModelShard pid=27475)[0m {'self': <__main__.RayModelShard object at 0x7fc0b6658370>, 'model': Sequential(
[2m[36m(RayModelShard pid=27475)[0m   (0): Linear(in_features=20, out_features=20, bias=True)
[2m[36m(RayModelShard pid=27475)[0m   (1): Linear(in_features=20, out_features=20, bias=True)
[2m[36m(RayModelShard pid=27475)[0m ), 'device_map': {'0': 1, '1': 2}, 'checkpoint_path': '/home/ray/tc-test-bloom/tc-test/index.json', 'device_id': 1}
[2m[36m(RayModelShard pid=27475)[0m Ray gpu ids [0]
[2m[36m(RayModelShard pid=27475)[0m device_ids 1 {'0': 1, '1': 'cpu'}
[2m[36m(RayModelShard pid=27475)[0m trying to load the model...
[2m[36m(RayModelShard pid=27476)[0m {'self': <__main__.RayModelShard object at 0x7f212ba48340>, 'model': Sequential(
[2m[36m(RayModelShard pid=27476)[0m   (0): Linear(in_features=20, out_features=20, bias=True)
[2m[36m(RayModelShard pid=27476)[0m   (1): Linear(in_features=20, out_features=20, bias=True)
[2m[36m(RayModelShar

In [16]:
ray_out = ray.get(device_actors[0].forward.remote(input))
# device_actors[0].forward(input)

In [17]:
ray_out

tensor([[-0.7395, -0.1088, -0.2778,  0.1993, -0.3198,  0.2734, -0.5200,  0.1238,
         -0.0372, -0.1720, -0.0598, -0.1271,  0.0536,  0.1790,  0.1112, -0.4418,
          0.1633, -0.0606,  0.0527,  0.2419]], requires_grad=True)

In [24]:
for c in real_model.__getattr__('0').named_modules():
    print(c)

('', Linear(in_features=20, out_features=20, bias=True))


In [22]:
ray.kill(ray.get_actor('shard4_2'))