<a href="https://colab.research.google.com/github/ybw9000/torch_Dtensor_playground/blob/main/torch_Dtensor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [95]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [7]:
from torch.distributed.tensor import distribute_tensor, DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Shard, Replicate, Partial
import torch.distributed.tensor as dtensor

In [12]:
import torch.multiprocessing as mp
import os

In [93]:
def setup_device(rank, world_size, master_addr="localhost", master_port="56492"):
    """Export the rendezvous environment variables for one worker process.

    The values are written into ``os.environ`` of the calling process, so
    every worker calls this before it creates its device mesh.

    Args:
        rank: Index of the calling process; exported as ``RANK``.
        world_size: Total number of processes; exported as ``WORLD_SIZE``.
        master_addr: Rendezvous host; exported as ``MASTER_ADDR``.
        master_port: Rendezvous port; exported as ``MASTER_PORT``. Pass a
            different value when the default port is already in use.
    """
    # need these otherwise it raises errors during process init
    os.environ["MASTER_ADDR"] = str(master_addr)
    os.environ["MASTER_PORT"] = str(master_port)
    os.environ["RANK"] = str(rank)  # Set the rank for each process
    os.environ["WORLD_SIZE"] = str(world_size)  # Set the world size

def kshard_gemm(rank, world_size):
    """Multiply two DTensors that are sharded along the contraction dim.

    The left operand is sharded on dim 1 and the right operand on dim 0, so
    the matmul result comes back as ``Partial(sum)`` and has to be
    redistributed to ``Replicate()`` to obtain the full product.

    Args:
        rank: Index of this worker process.
        world_size: Total number of worker processes; also the mesh size.
    """
    setup_device(rank, world_size)
    # Span the mesh over every launched rank instead of a fixed [0, 1], so the
    # mesh always matches the number of processes that were started.
    with DeviceMesh('cpu', list(range(world_size))):
        lhs = dtensor.ones((2, 2), placements=(Shard(1), ))
        rhs = dtensor.ones((2, 2), placements=(Shard(0), ))
        product = lhs.matmul(rhs)
        print(product.placements)  # (Partial(sum),)
        # local partial sums; with two ranks: [[1., 1.], [1., 1.]]
        print(product.to_local())
        # need a manual replicate to indicate an all-reduce
        reduced = product.redistribute(placements=(Replicate(), ))
        print(reduced.placements)
        print(reduced.to_local())  # [[2., 2.], [2., 2.]]

In [28]:
# Number of worker processes; reused by every launch cell in this notebook.
world_size = 2
# Start one forked worker per rank and wait for all of them to finish.
mp.start_processes(
    kshard_gemm,
    args=(world_size,),
    nprocs=world_size,
    join=True,
    start_method="fork",
)

(Partial(sum),)
tensor([[1., 1.],
        [1., 1.]])
(Partial(sum),)
tensor([[1., 1.],
        [1., 1.]])
(Replicate(),)(Replicate(),)

tensor([[2., 2.],
        [2., 2.]])tensor([[2., 2.],
        [2., 2.]])



In [73]:
def ffn(weight0, weight1, x, ln_weight, ln_bias):
    out = x.matmul(weight0).relu().matmul(weight1)
    return F.layer_norm(out, x.shape[1:], ln_weight, ln_bias, 1e-5)

In [74]:
# All-ones inputs keep the expected result easy to verify by hand.
x = torch.ones(2, 2)
weight0 = torch.ones(2, 2)
weight1 = torch.ones(2, 2)
# Layer-norm scale of 0.5 and shift of 1.5, one entry per normalized element.
ln_weight = torch.full(x.shape[1:], 0.5)
ln_bias = torch.full(x.shape[1:], 1.5)

In [75]:
# Single-process reference result for the all-ones inputs; the sharded run
# further below is expected to print the same values.
ffn(weight0, weight1, x, ln_weight, ln_bias)

tensor([[1.5000, 1.5000],
        [1.5000, 1.5000]])

In [88]:
def sharded_ffn(rank, world_size, weight0, weight1, x, ln_weight, ln_bias):
    """Run ``ffn`` on DTensors with the two matmul weights sharded.

    ``weight0`` is sharded on dim 1 and ``weight1`` on dim 0; the input and
    the layer-norm parameters are wrapped as replicated DTensors. The shards,
    the output placements and the local output are printed.

    Args:
        rank: Index of this worker process.
        world_size: Total number of worker processes; also the mesh size.
        weight0: Full (unsharded) weight of the first matmul.
        weight1: Full (unsharded) weight of the second matmul.
        x: Full input tensor, identical on every rank.
        ln_weight: Layer-norm scale, identical on every rank.
        ln_bias: Layer-norm shift, identical on every rank.
    """
    setup_device(rank, world_size)
    # Span the mesh over every launched rank instead of a fixed [0, 1], so the
    # mesh always matches the number of processes that were started.
    with DeviceMesh('cpu', list(range(world_size))):
        # Auto SPMD partition via the distribute_tensor api.
        # NOTE: distribute_tensor only works with leaf tensors, aka not activations.
        # The manual alternative is the from_local api, where each rank slices
        # out its own shard first, e.g.
        #   DTensor.from_local(weight0.reshape(weight0.shape[0], world_size, -1)[:, rank, :], placements=(Shard(1), ))
        #   DTensor.from_local(weight1.reshape(world_size, -1, weight1.shape[-1])[rank, :, :], placements=(Shard(0), ))
        Dweight0 = dtensor.distribute_tensor(weight0, placements=(Shard(1), ))
        Dweight1 = dtensor.distribute_tensor(weight1, placements=(Shard(0), ))
        print(Dweight0)
        print(Dweight1)
        # Every rank already holds the full tensor, so wrap it as replicated.
        Dx = DTensor.from_local(x, placements=(Replicate(), ))
        Dln_weight = DTensor.from_local(ln_weight, placements=(Replicate(), ))
        Dln_bias = DTensor.from_local(ln_bias, placements=(Replicate(), ))
        Dout = ffn(Dweight0, Dweight1, Dx, Dln_weight, Dln_bias)
        # No explicit redistribute to Replicate() is done here: the recorded
        # run prints (Replicate(),) for the output of the layer norm.
        print(Dout.placements)
        print(Dout.to_local())

In [87]:
# Run the sharded FFN on the all-ones inputs, one forked worker per rank.
mp.start_processes(
    sharded_ffn,
    args=(world_size, weight0, weight1, x, ln_weight, ln_bias),
    nprocs=world_size,
    join=True,
    start_method="fork",
)



DTensor(local_tensor=tensor([[1.],
        [1.]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=1),))DTensor(local_tensor=tensor([[1.],
        [1.]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=1),))

DTensor(local_tensor=tensor([[1., 1.]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=0),))DTensor(local_tensor=tensor([[1., 1.]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=0),))

(Replicate(),)(Replicate(),)

tensor([[1.5000, 1.5000],
        [1.5000, 1.5000]])
tensor([[1.5000, 1.5000],
        [1.5000, 1.5000]])


In [94]:
def ffn_no_ln(weight0, weight1, x):
    """Return ``relu(x @ weight0) @ weight1`` (the FFN without a layer norm)."""
    hidden = x.matmul(weight0)
    activated = hidden.relu()
    return activated.matmul(weight1)

def sharded_ffn_self_init(rank, world_size):
    """Run ``ffn_no_ln`` on DTensors that are created directly on the mesh.

    The input is replicated, the first weight is sharded on dim 1 and the
    second on dim 0. The result is redistributed to ``Replicate()`` before
    the local tensor is printed.

    Args:
        rank: Index of this worker process.
        world_size: Total number of worker processes; also the mesh size.
    """
    setup_device(rank, world_size)
    # Span the mesh over every launched rank instead of a fixed [0, 1], so the
    # mesh always matches the number of processes that were started.
    with DeviceMesh('cpu', list(range(world_size))):
        # random apis like randn do not work on cpus (author's note), hence ones
        Dx = dtensor.ones((2, 2), placements=(Replicate(), ))
        Dweight0 = dtensor.ones((2, 2), placements=(Shard(1), ))
        Dweight1 = dtensor.ones((2, 2), placements=(Shard(0), ))
        Dout = ffn_no_ln(Dweight0, Dweight1, Dx)
        # Without a layer norm nothing replicates the output, so do it here.
        Dout = Dout.redistribute(placements=(Replicate(), ))
        print(Dout.to_local())  # with two ranks: [[4., 4.], [4., 4.]]

In [85]:
# Run the self-initialising sharded FFN, one forked worker per rank.
mp.start_processes(
    sharded_ffn_self_init,
    args=(world_size,),
    nprocs=world_size,
    join=True,
    start_method="fork",
)

tensor([[4., 4.],
        [4., 4.]])tensor([[4., 4.],
        [4., 4.]])



In [90]:
# Random inputs, drawn in a fixed order, for a less trivial comparison between
# the single-process and the sharded FFN.
x = torch.randn(2, 2)
weight0 = torch.randn(2, 2)
weight1 = torch.randn(2, 2)
# Layer-norm parameters have one entry per normalized element of ``x``.
ln_weight = torch.randn(*x.shape[1:])
ln_bias = torch.randn(*x.shape[1:])

In [91]:
# Single-process reference result for the random inputs; the sharded run
# launched next is expected to print the same values.
ffn(weight0, weight1, x, ln_weight, ln_bias)

tensor([[1.4708, 0.5261],
        [1.2591, 1.6434]])

In [92]:
# Run the sharded FFN on the random inputs, one forked worker per rank.
mp.start_processes(
    sharded_ffn,
    args=(world_size, weight0, weight1, x, ln_weight, ln_bias),
    nprocs=world_size,
    join=True,
    start_method="fork",
)



DTensor(local_tensor=tensor([[-0.8675],
        [ 0.6228]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=1),))DTensor(local_tensor=tensor([[-1.4091],
        [ 0.8076]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=1),))
DTensor(local_tensor=tensor([[ 1.8077, -0.5340]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=0),))

DTensor(local_tensor=tensor([[-0.2621,  0.6153]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=0),))
(Replicate(),)(Replicate(),)
tensor([[1.4708, 0.5261],
        [1.2591, 1.6434]])

tensor([[1.4708, 0.5261],
        [1.2591, 1.6434]])


In [105]:
class MLP(nn.Module):
    """Two-layer MLP with a post layer norm and a residual connection.

    ``forward`` computes ``ln(down_gemm(relu(up_gemm(x)))) + x``, so the
    residual add requires ``out_dim == in_dim``.

    Args:
        in_dim: Feature size of the input.
        hidden_dim: Feature size between the two linear layers.
        out_dim: Feature size produced by the second linear layer.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.up_gemm = nn.Linear(in_dim, hidden_dim)
        self.down_gemm = nn.Linear(hidden_dim, out_dim)
        # The norm is applied to the output of down_gemm, so it is sized by
        # out_dim rather than in_dim.
        self.ln = nn.LayerNorm(out_dim)

    def forward(self, x):
        """Apply the MLP to ``x`` and add ``x`` back as a residual."""
        out = self.up_gemm(x)
        out = F.relu(out)
        out = self.down_gemm(out)
        # should be a prenorm yet we put it here to auto propagate an all-reduce
        out = self.ln(out)
        return out + x

    def to_dtensor(self):
        """Replace every parameter by a DTensor parameter, in place.

        No mesh is passed to ``distribute_tensor``, so this has to be called
        inside an active ``DeviceMesh`` context.

        ``nn.Linear`` stores its weight as ``(out_features, in_features)``,
        i.e. transposed relative to the ``(in, out)`` weights of the
        functional ``ffn`` examples. Sharding the hidden dimension therefore
        means dim 0 of ``up_gemm.weight`` (together with its bias) and dim 1
        of ``down_gemm.weight``.
        """
        # Up projection: split the hidden features; the bias has one entry
        # per hidden feature and is split the same way.
        self.up_gemm.weight = nn.Parameter(dtensor.distribute_tensor(self.up_gemm.weight, placements=(Shard(0), )))
        self.up_gemm.bias = nn.Parameter(dtensor.distribute_tensor(self.up_gemm.bias, placements=(Shard(0), )))
        # Down projection: split the contraction (hidden) dim; the bias is
        # per output feature and stays replicated.
        self.down_gemm.weight = nn.Parameter(dtensor.distribute_tensor(self.down_gemm.weight, placements=(Shard(1), )))
        self.down_gemm.bias = nn.Parameter(dtensor.distribute_tensor(self.down_gemm.bias, placements=(Replicate(), )))
        # Layer-norm parameters are kept whole on every rank.
        self.ln.weight = nn.Parameter(dtensor.distribute_tensor(self.ln.weight, placements=(Replicate(), )))
        self.ln.bias = nn.Parameter(dtensor.distribute_tensor(self.ln.bias, placements=(Replicate(), )))

In [98]:
def sharded_mlp(rank, world_size, model, x):
    """Run ``model`` on a replicated DTensor input and print the local output.

    Args:
        rank: Index of this worker process.
        world_size: Total number of worker processes; also the mesh size.
        model: Module exposing ``to_dtensor()``; its parameters are converted
            in place inside this process.
        x: Full input tensor, identical on every rank.
    """
    setup_device(rank, world_size)
    # Span the mesh over every launched rank instead of a fixed [0, 1], so the
    # mesh always matches the number of processes that were started.
    with DeviceMesh('cpu', list(range(world_size))):
        x = dtensor.distribute_tensor(x, placements=(Replicate(), ))
        # Parameters must become DTensors before they meet the DTensor input.
        model.to_dtensor()
        out = model(x)
        print(out.to_local())


In [102]:
# Random batch of two 128-wide rows; drawn before the model is initialised.
x = torch.randn(2, 128)
# in_dim equals out_dim so that the residual add in MLP.forward is valid.
mlp = MLP(in_dim=128, hidden_dim=512, out_dim=128)

In [103]:
# Single-process forward pass; the sharded run further below is expected to
# print the same values.
mlp(x)

tensor([[-1.4708,  0.1171,  1.7956, -2.0253, -3.2707, -0.2634, -1.3897,  0.5627,
          1.4990,  1.8911, -0.7884, -2.0117, -0.8167, -0.1332, -1.1860, -1.1104,
         -0.7379, -1.3655, -0.7776, -1.9322, -0.5902,  1.7879,  1.6158,  0.4063,
         -1.5333,  1.5448, -3.2080, -0.9128,  0.2484, -1.6977,  1.2191, -0.8326,
          0.6986,  0.3736,  1.4546,  0.0133, -1.0566, -0.8257,  0.7642,  0.2890,
         -1.7559, -0.0089,  0.1052,  0.4554,  1.7395, -2.4393,  0.8127, -0.7082,
          0.7158,  3.0469, -3.6085, -0.6601,  1.3015,  1.9104, -1.6649,  1.1639,
         -2.9202, -0.6016, -0.6613, -3.0803, -1.4661,  1.8791,  0.3352,  1.0842,
          1.1850, -0.6935,  0.2724, -2.2499, -0.4073,  0.2196, -0.0576, -1.8649,
          0.3465,  0.0938,  2.5835,  0.6649, -1.1765, -1.2938,  0.7223,  1.2641,
         -0.0663,  3.2695,  1.0942,  0.5473,  1.6449, -0.3207, -0.3700,  0.3826,
         -1.0532, -0.4288,  1.0137, -1.0178,  0.5352,  3.4009,  0.1376,  1.1851,
         -1.0735,  1.6331, -

In [104]:
# Run the sharded MLP, one forked worker per rank; each worker converts its
# own copy of ``mlp`` to DTensor parameters.
mp.start_processes(
    sharded_mlp,
    args=(world_size, mlp, x),
    nprocs=world_size,
    join=True,
    start_method="fork",
)



tensor([[-1.4708,  0.1171,  1.7956, -2.0253, -3.2707, -0.2634, -1.3897,  0.5627,
          1.4990,  1.8911, -0.7884, -2.0117, -0.8167, -0.1332, -1.1860, -1.1104,
         -0.7379, -1.3655, -0.7776, -1.9322, -0.5902,  1.7879,  1.6158,  0.4063,
         -1.5333,  1.5448, -3.2080, -0.9128,  0.2484, -1.6977,  1.2191, -0.8326,
          0.6986,  0.3736,  1.4546,  0.0133, -1.0566, -0.8257,  0.7642,  0.2890,
         -1.7559, -0.0089,  0.1052,  0.4554,  1.7395, -2.4393,  0.8127, -0.7082,
          0.7158,  3.0469, -3.6085, -0.6601,  1.3015,  1.9104, -1.6649,  1.1639,
         -2.9202, -0.6016, -0.6613, -3.0803, -1.4661,  1.8791,  0.3352,  1.0842,
          1.1850, -0.6935,  0.2724, -2.2499, -0.4073,  0.2196, -0.0576, -1.8649,
          0.3465,  0.0938,  2.5835,  0.6649, -1.1765, -1.2938,  0.7223,  1.2641,
         -0.0663,  3.2695,  1.0942,  0.5473,  1.6449, -0.3207, -0.3700,  0.3826,
         -1.0532, -0.4288,  1.0137, -1.0178,  0.5352,  3.4009,  0.1376,  1.1851,
         -1.0735,  1.6331, -

In [114]:
def my_custom_backend(gm: torch.fx.GraphModule, example_inputs):
    """Debug backend for ``torch.compile``: print the graph, then run it as-is.

    Args:
        gm: The captured graph module.
        example_inputs: The example inputs handed to the backend.

    Returns:
        ``gm.forward`` unchanged, i.e. no compilation is performed.
    """
    graph_nodes = gm.graph.nodes
    for fx_node in graph_nodes:
        print(fx_node)
    print(example_inputs)
    passthrough = gm.forward
    return passthrough

# Compiled through the printing backend, so the captured graph nodes and the
# example inputs are dumped when the function is first traced.
@torch.compile(backend=my_custom_backend, )
def ffn_compile(weight0, weight1, x):
    """Return ``relu(x @ weight0) @ weight1``; same math as ``ffn_no_ln``."""
    # NOTE(review): the recorded graph dump contains a node named ``out``, so
    # node names appear to follow the local variable names here - keep this
    # expression and the name ``out`` as they are if the printed graph matters.
    out = x.matmul(weight0).relu().matmul(weight1)
    return out

def sharded_ffn_compile(rank, world_size):
    """Run the ``torch.compile``-wrapped FFN on sharded DTensors.

    Same setup as ``sharded_ffn_self_init`` but calls ``ffn_compile``, so the
    custom backend prints the captured graph and its example inputs before
    the replicated result is printed.

    Args:
        rank: Index of this worker process.
        world_size: Total number of worker processes; also the mesh size.
    """
    setup_device(rank, world_size)
    # Span the mesh over every launched rank instead of a fixed [0, 1], so the
    # mesh always matches the number of processes that were started.
    with DeviceMesh('cpu', list(range(world_size))):
        # random apis like randn do not work on cpus (author's note), hence ones
        Dx = dtensor.ones((2, 2), placements=(Replicate(), ))
        Dweight0 = dtensor.ones((2, 2), placements=(Shard(1), ))
        Dweight1 = dtensor.ones((2, 2), placements=(Shard(0), ))
        Dout = ffn_compile(Dweight0, Dweight1, Dx)
        # Without a layer norm nothing replicates the output, so do it here.
        Dout = Dout.redistribute(placements=(Replicate(), ))
        print(Dout.to_local())  # with two ranks: [[4., 4.], [4., 4.]]

In [115]:
# Run the compiled sharded FFN, one forked worker per rank.
mp.start_processes(
    sharded_ffn_compile,
    args=(world_size,),
    nprocs=world_size,
    join=True,
    start_method="fork",
)

l_x_
l_weight0_l_x_

l_weight0_
l_weight1_l_weight1_

matmul
matmulrelu
relu

out
outoutput

output
[DTensor(local_tensor=tensor([[1., 1.],
        [1., 1.]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Replicate(),)), DTensor(local_tensor=tensor([[1.],
        [1.]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=1),)), DTensor(local_tensor=tensor([[1., 1.]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=0),))]
[DTensor(local_tensor=tensor([[1., 1.],
        [1., 1.]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Replicate(),)), DTensor(local_tensor=tensor([[1.],
        [1.]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=1),)), DTensor(local_tensor=tensor([[1., 1.]]), device_mesh=DeviceMesh('cpu', [0, 1]), placements=(Shard(dim=0),))]
tensor([[4., 4.],
        [4., 4.]])tensor([[4., 4.],
        [4., 4.]])

