# Tensor Parallelism

# Prepare data

$$
y = xw
$$

x(n, dim)

w(dim, 2dim)

y(n, 2dim)

y_label(n, 2dim)

gradient(dim, 2dim)


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Problem sizes: x is (bs, dim), w is (dim, out_dim) with out_dim = 2 * dim.
bs = 2
row = dim = 4
col = out_dim = dim * 2

# Deterministic inputs so every gradient printed later can be checked by hand.
x = torch.arange(bs * dim, dtype=torch.float32).reshape(bs, dim)

# Build the values first and enable autograd last, so that `w` itself is the leaf
# tensor. Passing requires_grad=True to arange() and reshaping afterwards would
# make `w` a non-leaf view whose .grad is only populated after retain_grad().
w = torch.arange(dim * out_dim, dtype=torch.float32).reshape(dim, out_dim).requires_grad_()

print(x)
print(w)

In [3]:
# Seed the RNG so the random targets (and every loss/gradient printed below)
# are identical on each Restart & Run All.
torch.manual_seed(0)

# Random regression targets with the same shape as the model output (bs, out_dim).
# No placeholder y_pred is created here: the real prediction is computed as x @ w
# in the cells that need it.
y_label = torch.randn(bs, out_dim)

# pytorch gradient

先采用pytorch自动求导计算梯度。可以用来检验我们手动求导的准确性

In [4]:
# Reference gradient from autograd.
# retain_grad() guarantees w.grad gets populated even when w is not a leaf tensor;
# clearing w.grad first keeps this cell idempotent (no accumulation on re-run).
w.retain_grad()
w.grad = None

# Mean-reduced MSE: the squared error is averaged over all bs * out_dim elements.
mse_loss = nn.MSELoss(reduction='mean')
y_pred = torch.matmul(x, w)
loss_torch = mse_loss(y_pred, y_label)
loss_torch.backward()

print(loss_torch, w.grad, sep="\n")

# tensor gradient

$$
\begin{align}
y &= xw \\
l &= \frac{1}{N}||xw-y'||^2, \quad N = n \cdot 2\,dim \quad (\text{mean over all elements of } y) \\
w^* &= \arg\min_{w} l \\
\frac{\partial l}{\partial w} &=\frac{\partial l}{\partial y}\frac{\partial y}{\partial w} = \frac{2}{N}x^T(xw-y')\\
\end{align}
$$

In [5]:
# Forward pass, then the residual (prediction minus label) that the
# hand-derived gradient is built from.
y_pred = torch.matmul(x, w)
delta_y = torch.sub(y_pred, y_label)

In [6]:
# Hand-derived gradient of the mean-reduced MSE loss with respect to w:
#   d/dw [ sum((x w - y')^2) / (bs * out_dim) ] = 2 / (bs * out_dim) * x^T (x w - y')
# The scale must include both the factor 2 from the square and the batch size;
# a bare 1 / out_dim gives the same number only in the special case bs == 2.
delta_w = (2.0 / (bs * out_dim)) * (x.t() @ delta_y)
print(delta_w)

In [7]:
# Recompute the autograd gradient and compare it with the hand-derived delta_w.
# retain_grad() guarantees w.grad gets populated even when w is not a leaf tensor;
# resetting w.grad avoids accumulating on top of an earlier backward pass.
w.retain_grad()
w.grad = None
mse_loss = nn.MSELoss(reduction='mean')

y_pred = x @ w
loss_torch = mse_loss(y_pred, y_label)
loss_torch.backward()
print(loss_torch)
print(w.grad)

# Make the comparison explicit instead of relying on eyeballing two printouts.
# delta_w is detached because it may still carry an autograd graph through w.
assert torch.allclose(w.grad, delta_w.detach()), \
    "manual gradient delta_w does not match autograd w.grad"

# Tensor Parallel

In [8]:
import copy
# Two separate weight tensors, one per sharding scheme, built with the same
# arange/reshape recipe and (dim, out_dim) shape as w. They are created without
# requires_grad: their gradients are computed by hand in the sections below.
w_row, w_col = (
    torch.arange(dim * out_dim, dtype=torch.float32).reshape(dim, out_dim)
    for _ in range(2)
)

## row-tensor parallel

In [9]:
# Row-parallel sharding: cut w_row along axis 0 at row_dim, giving a top shard
# (rows [0, row_dim)) and a bottom shard (rows [row_dim, dim)).
row_dim = dim // 2
w_row_1, w_row_2 = torch.tensor_split(w_row, [row_dim], dim=0)

# Shapes first, then the shard contents.
print(w_row_1.shape, w_row_2.shape, sep="\n")
print(w_row_1, w_row_2, sep="\n")

In [17]:
# The input is split along its feature axis at the same index used for the
# rows of w_row, so each x slice lines up with one weight shard.
x_col_1, x_col_2 = torch.tensor_split(x, [row_dim], dim=1)

# Each product is a partial result with the full output width; the complete
# output is their element-wise sum.
y_1 = x_col_1 @ w_row_1
y_2 = x_col_2 @ w_row_2

print(y_1.shape, y_2.shape, (y_1 + y_2).shape, sep="\n")

In [11]:
# Residual of the summed partial outputs against the labels.
delta_y = y_1 + y_2 - y_label

# Each weight shard's (unnormalised) gradient only needs its own slice of x
# together with the shared residual.
grad_row_1 = torch.matmul(x_col_1.t(), delta_y)
grad_row_2 = torch.matmul(x_col_2.t(), delta_y)

print(grad_row_1.shape, grad_row_2.shape, sep="\n")

In [12]:
# Stack the two shard gradients back along the row axis to recover the full
# (dim, out_dim) gradient, then apply the mean-MSE scale 2 / (bs * out_dim).
# A bare 1 / out_dim gives the same number only in the special case bs == 2.
grad_row = (2.0 / (bs * out_dim)) * torch.cat((grad_row_1, grad_row_2), dim=0)
print(grad_row)

## col-tensor parallel

In [13]:
# Column-parallel sharding: cut w_col along axis 1 at col_dim, giving a left shard
# (columns [0, col_dim)) and a right shard (columns [col_dim, out_dim)).
col_dim = out_dim // 2
w_col_1, w_col_2 = torch.tensor_split(w_col, [col_dim], dim=1)

# Shapes first, then the shard contents.
print(w_col_1.shape, w_col_2.shape, sep="\n")
print(w_col_1, w_col_2, sep="\n")

In [14]:
# Every shard sees the full input; each product yields one column block of the
# output (y_1 from the left weight shard, y_2 from the right one).
y_1 = torch.matmul(x, w_col_1)
y_2 = torch.matmul(x, w_col_2)

print(y_1.shape, y_2.shape, sep="\n")

In [15]:
# Each output block must be compared with the matching column slice of the labels:
# y_1 was produced by the first col_dim columns of w_col, so it pairs with
# y_label[:, :col_dim]; y_2 came from the remaining columns and pairs with
# y_label[:, col_dim:]. Pairing them the other way round yields residuals (and
# therefore gradients) for the wrong targets.
y_1_delta = y_1 - y_label[:, :col_dim]
y_2_delta = y_2 - y_label[:, col_dim:]
print(y_1_delta.shape)
print(y_2_delta.shape)

In [16]:
# Per-shard (unnormalised) gradients: every shard uses the full x and its own
# column block of the residual.
grad_col_1 = x.t() @ y_1_delta
grad_col_2 = x.t() @ y_2_delta

# Join the column blocks back into a (dim, out_dim) gradient and apply the
# mean-MSE scale 2 / (bs * out_dim). A bare 1 / out_dim gives the same number
# only in the special case bs == 2.
grad_col = (2.0 / (bs * out_dim)) * torch.cat((grad_col_1, grad_col_2), dim=1)

print(grad_col_1.shape)
print(grad_col_2.shape)
print(grad_col.shape)
print(grad_col_1)
print(grad_col_2)
print(grad_col)