# Distributed training with QPS optimizations

In this notebook we look at a range of techniques to speed up the training of a transformer model. We will then adopt distributed data parallel and full use of the 8 GPUs on a single node. We will work with shakespeare texts. 

# Prepare data

In [1]:
import torch


In [2]:
with open('shakespeare.txt', 'r') as f:
    text = f.read()
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [3]:
chars = sorted(list(set(text)))
char_to_ix = {ch:i for i,ch in enumerate(chars)}
ix_to_char = {i:ch for i,ch in enumerate(chars)}

def encode(text):
    return [char_to_ix[ch] for ch in text]

def decode(vec):
    return ''.join([ix_to_char[ix] for ix in vec])

print(''.join(chars))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [4]:
text_vec = encode(text)
data = torch.tensor(text_vec, dtype=torch.long)
print(data.shape)
print(data[:100])

torch.Size([1115394])
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [5]:
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [6]:
torch.manual_seed(1337)

def get_batch(split, batch_size, block_size):
    """
    Each data is multiple examples
    """
    if split == 'train':
        data = train_data
    else:
        # (N x 1)
        data = val_data
    # B x 1
    start = torch.randint(0, len(data) - block_size, (batch_size,))
    # [1, 2, 3, 4, 5], block_size = 3
    # [[1, 2, 3], [2, 3, 4], [3, 4, 5]]
    x = torch.stack([data[s:s+block_size] for s in start])
    # [[2, 3, 4], [3, 4, 5], [4, 5, 6]]
    y = torch.stack([data[s+1:s+block_size+1] for s in start])
    return x, y

x, y = get_batch('train', 4, 8)
# this is 4 * 8 samples 
print(x.shape, y.shape)
print(x)
print(y)

torch.Size([4, 8]) torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


In [7]:
@torch.no_grad()
def estimate_loss(batch_size, block_size, eval_iters=100):
    out = {}
    m.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split, batch_size, block_size)
            logits, loss = m(x, y)
            losses[k] = loss
        out[f'{split}_loss'] = losses.mean().item()
    m.train()
    return out 

In [8]:
# training loop
def train(model, optimizer, batch_size, block_size, num_steps, eval_iters=100):
    for step in range(num_steps):
        x, y = get_batch('train', batch_size, block_size)
        optimizer.zero_grad(set_to_none=True)
        logits, loss = model(x, y)
        loss.backward()
        optimizer.step()
        
        if step % 500 == 0:
            print(estimate_loss(batch_size, block_size, eval_iters))

    return loss

# Transformer model

In [9]:
class FeedForward(torch.nn.Module):
    def __init__(self, in_size, dropout):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_size, in_size * 4),
            torch.nn.ReLU(),
            torch.nn.Linear(in_size * 4, in_size),
            # add dropout
            torch.nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

In [10]:
class Head(torch.nn.Module):
    def __init__(self, emb_size, head_size, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.k = torch.nn.Linear(emb_size, head_size)
        self.q = torch.nn.Linear(emb_size, head_size)
        self.v = torch.nn.Linear(emb_size, head_size)
        # add dropout
        self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, x):
        # x: B x T x E
        B, T, E = x.shape
        # B x T x L
        k = self.k(x)
        # B x T x L
        q = self.q(x)
        # B x T x L
        v = self.v(x)
        # B x T x T
        att = k @ q.transpose(-2, -1)
        att = att / (k.shape[-1] ** 0.5)
        # B x T x T
        mask = torch.tril(torch.ones(T, T))
        # B x T x T
        masked_att = att.masked_fill(mask == 0, float('-inf'))
        # B x T x T
        masked_att = torch.nn.functional.softmax(masked_att, dim=-1)
        masked_att = self.dropout(masked_att)
        # B x T x L
        att_bow = masked_att @ v
        return att_bow

In [11]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, in_size, head_size, num_heads, dropout):
        super().__init__()
        self.heads = torch.nn.ModuleList([
            Head(in_size, head_size, dropout) for _ in range(num_heads)
        ])
        self.proj = torch.nn.Linear(head_size*num_heads, in_size)
        # add dropout
        self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, x):
        # x: B x T x E
        B, T, E = x.shape
        # B x T x n x H
        att_bows = torch.stack([head(x) for head in self.heads], dim=2)
        # B x T x n*H
        att_bows = att_bows.view(B, T, -1)
        # B x T x E
        out = self.proj(att_bows)
        out = self.dropout(out)
        return out

In [12]:
class ResidualBlock(torch.nn.Module):
    def __init__(self, in_size, head_size, num_heads, dropout):
        super().__init__()
        self.multihead = MultiHeadAttention(in_size, head_size, num_heads, dropout)
        self.feedforward = FeedForward(in_size, dropout)
        self.layernorm1 = torch.nn.LayerNorm(in_size)
        self.layernorm2 = torch.nn.LayerNorm(in_size)
    
    def forward(self, x):
        # x: B x T x in
        # B x T x in
        x = self.layernorm1(x)
        # B x T x in_size
        x = x + self.multihead(x)
        # B x T x in_size
        x = self.layernorm2(x)
        # B x T x in_size
        x = x + self.feedforward(x)
        return x

In [13]:
class AttentionModel(torch.nn.Module):
    def __init__(self, block_size, vocab_size, num_heads, hidden_size, num_blocks, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_emb = torch.nn.Embedding(vocab_size, hidden_size)
        self.position_emb = torch.nn.Embedding(block_size, hidden_size)
        self.blocks = torch.nn.ModuleList([
                ResidualBlock(hidden_size, hidden_size // num_heads, num_heads, dropout) for _ in range(num_blocks)
            ] + [
                torch.nn.LayerNorm(hidden_size)
            ]
        )
        self.out = torch.nn.Linear(hidden_size, vocab_size)
        
    def forward(self, idx, target):
        # idx: B x T
        B, T = idx.shape
        # B x T x E (hidden_size)
        token_emb = self.token_emb(idx)  
        # position_idx: block_size
        position_idx = torch.arange(T)
        # position_emb: block_size x emb_size
        position_emb = self.position_emb(position_idx)
        # B x T x E
        x = token_emb + position_emb
        # B x T x E
        for block in self.blocks:
            x = block(x)   
        # B x T x C
        logits = self.out(x)
        
        if target is not None:
            # target: B x block_size
            target = target.view(-1)
            # B*block_size x vocab_size
            logits = logits.view(B*T, -1)
            loss = torch.nn.functional.cross_entropy(logits, target)
            # B x block_size x vocab_size
            logits = logits.view(B, T, -1)
        else:
            loss = None
        return logits, loss 

    def generate(self, idx, max_new_tokens):
        # idx: B x T
        for _ in range(max_new_tokens):
            # B x T x C
            logits, _ = self(idx[:, -self.block_size:], None)
            # only interested in prediction from last token
            # B x 1 x C
            logits = logits[:, -1, :]
            # B x 1 x C
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # B x 1
            n = torch.multinomial(probs, num_samples=1)
            # B x T+1
            idx = torch.cat([idx, n], dim=1)
        return idx

In [16]:
block_size = 16
m = AttentionModel(
    block_size=block_size, 
    vocab_size=len(chars), 
    num_heads=3,
    hidden_size=32,
    num_blocks=2,
    dropout=0.2,
)
optimizer = torch.optim.AdamW(m.parameters(), lr=0.0003)
train(
    model=m, 
    optimizer=optimizer, 
    batch_size=16,  
    block_size=block_size,
    num_steps=1000,
    eval_iters=50,
)

{'train_loss': 4.293725967407227, 'val_loss': 4.292873382568359}
{'train_loss': 2.8864548206329346, 'val_loss': 2.9066147804260254}


tensor(2.6806, grad_fn=<NllLossBackward0>)