### Tensor Parallelism

##### Example 1

In [1]:
import torch

In [2]:
inputs = torch.tensor([
    [0, 1, 2, 3],
    [4, 5, 6, 7]
])

In [3]:
weights = torch.tensor([
    [10, 14],
    [11, 15],
    [12, 16],
    [13, 17]
])

In [4]:
outputs = inputs @ weights

In [5]:
outputs.shape

torch.Size([2, 2])

In [24]:
inputs = torch.randn(4, 8)
weights = torch.randn(8, 6)

In [62]:
import torch

In [63]:
inputs.shape, weights.shape

(torch.Size([4, 8]), torch.Size([8, 6]))

In [64]:
outputs = inputs @ weights

Compute **the matrix multiplication operation** using tensor parallelism with **a factor of 2**.

In [65]:
def by_column_parallelism(inputs, weights):
    n_cols = weights.shape[-1]    
    w1, w2 = weights[:, :n_cols//2], weights[:, n_cols//2:]
    out1 = inputs @ w1
    out2 = inputs @ w2
    return torch.cat([out1, out2], dim=-1)

In [66]:
column_output = by_column_parallelism(inputs, weights)

In [67]:
torch.allclose(outputs, column_output)

True

In [68]:
def by_row_parallelism(inputs, weights):
    x_dim_partrition = inputs.shape[-1] // 2
    w_dim_partrition = weights.shape[0] // 2
    x1, x2 = inputs[:, :x_dim_partrition], inputs[:, x_dim_partrition:]
    w1, w2 = weights[:w_dim_partrition, :], weights[w_dim_partrition:, :]
    out1 = x1 @ w1
    out2 = x2 @ w2    
    return out1 + out2

In [69]:
row_output = by_row_parallelism(inputs, weights)

In [70]:
torch.allclose(outputs, row_output)

True

### Cross Entropy

##### Example 0

In [None]:
import torch

In [None]:
a = torch.randn(1, 2, 3, 4)
b = torch.randn(2, 12)

In [None]:
a.shape, b.shape

(torch.Size([1, 2, 3, 4]), torch.Size([2, 12]))

In [None]:
a.numel() == b.numel()

True

Change the shape of `a` same as `b` using PyTorch's built-in

In [None]:
# Use view_as to reshape 'a' to the size of 'b'
c = a.view_as(b)

print(c.size())  # prints: torch.Size([2, 12])

torch.Size([2, 12])


##### Example 1

In [None]:
import torch

**Hint**
- `torch.distributed.ReduceOp.MAX`

In [None]:
def get_vocab_range_for_partition(partition_vocab_size, rank, world_size):
    pass

In [None]:
targets = torch.randint(low=0, high=100, size=(50,))

In [None]:
targets

tensor([26, 35, 89, 40,  3, 96, 26, 60, 65, 99, 51,  5, 81, 24, 29, 74, 82, 53,
        16, 66, 31, 75,  2, 60, 43, 58, 45, 97, 99,  9, 16, 58, 36, 49, 81, 35,
        12, 96, 69, 11, 67, 93, 79, 50,  7, 11, 67,  0, 72, 26])

In [None]:
batch_size = 10
seq_len = 20
vocab_size = 100

In [None]:
parallel_logits = torch.randn(batch_size, seq_len, vocab_size)

In [None]:
rank = 2

In [None]:
world_size = 4

In [None]:
partition_size = 25

In [None]:
class VocabParallelEmbedding(torch.autograd.Function):
    def forward(self, vocab_parallel_logits, target):
        # return values, and indicies
        logits_max, _ = torch.max(vocab_parallel_logits, dim=-1)
        
        torch.distributed.all_reduce(
            logits_max,
            op=torch.distributed.ReduceOp.MAX
        )
        
        # rank = torch.distributed.get_rank()
        # world_size = torch.distributed.get_world_size()
        
        vocab_start_idx, vocab_end_idx = get_vocab_range_for_partition(
            partition_size, rank, world_size
        )
        
        target_mask = (target < vocab_start_idx) | (target >= vocab_end_idx)
        masked_target = target.clone() - vocab_start_idx
        masked_target[target_mask] = 0

In [None]:
xs = torch.arange(4*6).view(4, 6)

In [None]:
xs

tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]])

In [None]:
torch.max(xs, dim=-1)[0]

tensor([ 5, 11, 17, 23])

In [None]:
xs - (torch.max(xs, dim=-1)[0]).unsqueeze(dim=-1)

tensor([[-5, -4, -3, -2, -1,  0],
        [-5, -4, -3, -2, -1,  0],
        [-5, -4, -3, -2, -1,  0],
        [-5, -4, -3, -2, -1,  0]])

### Data Preprocessing

##### Draft 1

In [None]:
import numpy as np
from typing import List

class CachedDataset:
    def __init__(self, filename):
        self.filename = filename
        self.cache = None
        self.cache_index = {}
    
    def prefetch(self, indices: List[int]):
        if all(i in self.cache_index for i in indices):
            return 
        if not self.cache:
            # Load data into memory
            self.data = np.load(self.filename)
        
        # Get the total size of all samples in indices    
        total_size = sum([self.data[i].size for i in indices])
        
        # Allocate memory for cache
        self.total_size = total_size
        self.cache = np.empty(total_size, dtype=self.data.dtype)
        
        # Copy data into cache
        offset = 0
        for i in indices:
            size = self.data[i].size 
            self.cache[offset:offset+size] = self.data[i]
            self.cache_index[i] = offset
            offset += size
            
    def __getitem__(self, i):
        if i not in self.cache_index:
            self.prefetch([i])
        
        start, stop = self.cache_index[i], self.cache_index[i] + self.data[i].size  
        sample = self.cache[start:stop]
        return sample

In [None]:
# Init dataset 
dataset = CachedDataset("data.npy")

In [None]:
#indices = list(range(10))  
indices = [69, 42]

In [None]:
dataset.prefetch(indices)

In [None]:
len(dataset.data)

100

In [None]:
dataset.cache_index

{69: 0, 42: 32}

In [None]:
sample = dataset[69]

In [None]:
print(sample)

[0.90727821 0.47052171 0.09785945 0.76302124 0.10989286 0.53942689
 0.56296104 0.27903864 0.93956806 0.81588349 0.92999636 0.66565923
 0.73933048 0.27453693 0.50694107 0.54195803 0.71630134 0.11058684
 0.49252249 0.31574857 0.88411237 0.89961832 0.40477919 0.77834166
 0.75873789 0.84388431 0.24626659 0.23231936 0.56750329 0.75355609
 0.17288434 0.65904373]


##### Draft 2

In [None]:
import torch

samples = []

for i in range(100):
    sample = torch.rand(32)  # Random array of 32 floats
    samples.append(sample)

samples = torch.stack(samples)  # Stack all tensor elements in the list

# Save to a file
torch.save(samples, "data.pt")  # PyTorch typically uses the .pt or .pth file extension


In [None]:
import torch
from typing import List

class CachedDataset:
    def __init__(self, filename):
        self.filename = filename
        self.cache = None
        self.cache_index = {}

    def prefetch(self, indices: List[int]):
        if all(i in self.cache_index for i in indices):
            return 
        if self.cache is None:
            # Load data into memory
            self.data = torch.load(self.filename)

        # Get the total size of all samples in indices
        total_size = sum([self.data[i].numel() for i in indices])

        # Allocate memory for cache
        self.total_size = total_size
        self.cache = torch.empty(total_size, dtype=self.data.dtype)

        # Copy data into cache
        offset = 0
        for i in indices:
            size = self.data[i].numel()
            self.cache[offset:offset+size] = self.data[i].view(-1)
            self.cache_index[i] = offset
            offset += size

    def __getitem__(self, i):
        if i not in self.cache_index:
            self.prefetch([i])

        start, stop = self.cache_index[i], self.cache_index[i] + self.data[i].numel()
        sample = self.cache[start:stop]
        return sample.view_as(self.data[i])


In [None]:
# Init dataset 
dataset = CachedDataset("data.pt")

In [None]:
#indices = list(range(10))  
indices = [69, 42]

In [None]:
dataset.prefetch(indices)

In [None]:
len(dataset.data)

100

In [None]:
dataset.cache_index

{69: 0, 42: 32}

In [None]:
sample = dataset[69]

In [None]:
print(sample)

tensor([0.1744, 0.6199, 0.7196, 0.9808, 0.3005, 0.0986, 0.9263, 0.1980, 0.8668,
        0.0573, 0.3137, 0.6258, 0.7518, 0.0397, 0.1570, 0.4208, 0.2324, 0.2188,
        0.7201, 0.1034, 0.9380, 0.0234, 0.5270, 0.3557, 0.2978, 0.3853, 0.1590,
        0.0405, 0.4142, 0.3107, 0.6874, 0.9507])
