In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
# Character vocabulary in token-id order: a character's position in this
# string is its integer id. Both lookup tables below are derived from it.
_VOCAB = (
    "abcdefghijklmnopqrstuvwxyz"    # ids 0-25
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"    # ids 26-51
    "0123456789"                    # ids 52-61
    ".,;:?!'\""                     # ids 62-69
    "-()[]{}"                       # ids 70-76
    " \n\t"                         # ids 77-79
    "@#$%^&*_"                      # ids 80-87
    "+=/\\|~`"                      # ids 88-94
    "<>–—"                          # ids 95-98
)

# id -> character
itos = dict(enumerate(_VOCAB))
# character -> id (exact inverse of `itos`)
stoi = {ch: idx for idx, ch in itos.items()}


In [5]:
# Display the number of entries in the id -> character table (the vocabulary size).
len(itos)

99

In [15]:
class config:
    """Hyperparameters, read as plain class attributes by the model classes."""

    # --- sizes ---
    vocab_size = len(itos)      # one id per entry of the `itos` table
    n_embd = 2                  # embedding width
    n_hidden = n_embd * 4       # hidden width of the feed-forward layer

    # --- depth / heads ---
    n_layers = 1
    n_heads = 1

    # --- sizes of the positional embedding tables ---
    w_block_size = 10           # rows of the `wpe` table
    c_block_size = 15           # rows of the `cpe` table

    # --- regularisation ---
    dropout_ratio = 0.2         # probability passed to nn.Dropout

In [None]:
# Attention cpu version

class CharAttention(nn.Module):
    """Causal self-attention along the character axis, pooled to one vector per word.

    The input is a 4-D tensor unpacked as ``(B, W, c, C)`` where ``C`` must equal
    ``config.n_embd`` (presumably batch, words, characters-per-word, channels).
    Attention runs over axis ``c`` independently for every ``(B, W)`` pair; the
    attended values are projected, passed through dropout and added back to the
    input. Only the last character position is returned, giving ``(B, W, C)``.
    """

    def __init__(self):
        super().__init__()
        # One fused projection producing query, key and value side by side.
        self.attn = nn.Linear(config.n_embd, 3*config.n_embd, bias = False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias = False)
        self.dropout = nn.Dropout(config.dropout_ratio)

    def forward(self, x):
        """Return the residual-updated embedding of the last character of each word.

        Args:
            x: tensor of shape ``(B, W, c, C)`` with ``C == config.n_embd``.

        Returns:
            Tensor of shape ``(B, W, C)``.
        """
        B, W, c, C = x.shape
        head_dim = C // config.n_heads

        def split_heads(t):
            # (B, W, c, C) -> (B, W, n_heads, c, head_dim): the last two axes are
            # the (sequence, feature) pair that scaled_dot_product_attention uses.
            return t.view(B, W, c, config.n_heads, head_dim).transpose(2, 3)

        qkv = self.attn(x)
        # Split the fused projection held in the local `qkv`; the module has no
        # `self.qkv` attribute.
        q, k, v = qkv.split(config.n_embd, dim = -1)
        out = F.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v), is_causal = True
        )
        # Merge heads back: (B, W, n_heads, c, head_dim) -> (B, W, c, C)
        out = out.transpose(2, 3).contiguous().view(B, W, c, C)

        out = self.c_proj(out)
        out = self.dropout(out)
        out = x + out             # Residual connection
        return out[:, :, -1, :]   # B, W, C



class WordAttention(nn.Module):
    """Causal multi-head self-attention over the word axis.

    The input is a 3-D tensor unpacked as ``(B, W, C)`` with ``C == config.n_embd``.
    The output has the same shape. No residual connection is added here.
    """

    def __init__(self):
        super().__init__()
        # One fused projection producing query, key and value side by side.
        self.attn = nn.Linear(config.n_embd, 3*config.n_embd, bias = False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias = False)
        self.dropout = nn.Dropout(config.dropout_ratio)

    def forward(self, x):
        """Attend causally along axis 1 and project the result.

        Args:
            x: tensor of shape ``(B, W, C)`` with ``C == config.n_embd``.

        Returns:
            Tensor of shape ``(B, W, C)``.
        """
        B, W, C = x.shape
        head_dim = C // config.n_heads

        def split_heads(t):
            # (B, W, C) -> (B, n_heads, W, head_dim)
            return t.view(B, W, config.n_heads, head_dim).transpose(1, 2)

        qkv = self.attn(x)
        # Split the fused projection held in the local `qkv`; the module has no
        # `self.qkv` attribute.
        q, k, v = qkv.split(config.n_embd, dim = -1)
        out = F.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v), is_causal = True
        )
        # Merge heads back: (B, n_heads, W, head_dim) -> (B, W, C)
        out = out.transpose(1, 2).contiguous().view(B, W, C)

        out = self.c_proj(out)
        out = self.dropout(out)
        return out


        
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Maps the last axis from ``config.n_embd`` to ``config.n_hidden`` and back,
    so the output shape equals the input shape.
    """

    def __init__(self):
        super().__init__()
        # nn.Sequential only accepts named children through an OrderedDict.
        from collections import OrderedDict

        # Layers are passed as named entries (not assignments, which are a
        # syntax error inside a call); they are reachable as `self.net.c_fc`,
        # `self.net.gelu`, `self.net.c_proj` and `self.net.dropout`.
        self.net = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(config.n_embd, config.n_hidden, bias = False)),
            ("gelu", nn.GELU()),
            ("c_proj", nn.Linear(config.n_hidden, config.n_embd, bias = False)),
            ("dropout", nn.Dropout(config.dropout_ratio)),
        ]))

    def forward(self, x):
        """Apply the feed-forward stack to the last axis of ``x``.

        Args:
            x: tensor whose last axis has size ``config.n_embd``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        return self.net(x)



class Block(nn.Module):
    """Pre-norm transformer block: word attention then MLP, each with a residual.

    Both sub-layers receive a LayerNorm-ed copy of the running activation and
    their output is added back to it, so the shape of ``x`` is preserved.
    """

    def __init__(self):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.w_attn = WordAttention()
        self.mlp = MLP()
        self.ln_2 = nn.LayerNorm(config.n_embd)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates.

        Args:
            x: tensor whose last axis has size ``config.n_embd``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Each sub-layer normalises the current activation `x`; the second
        # residual uses the result of the first.
        x = x + self.w_attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


        
class GPT(nn.Module):
    """Two-level character/word model producing per-character logits for each word.

    Pipeline in ``forward``:
      1. embed character ids and add character-position embeddings,
      2. ``CharAttention`` reduces each word's characters to one vector,
      3. add word-position embeddings and run the ``Block`` stack,
      4. one linear head per character slot maps each word vector to
         ``config.vocab_size`` logits.
    """

    def __init__(self):
        super().__init__()
        self.cte = nn.Embedding(config.vocab_size, config.n_embd)    # character token embedding
        self.cpe = nn.Embedding(config.c_block_size, config.n_embd)  # character position embedding
        self.wpe = nn.Embedding(config.w_block_size, config.n_embd)  # word position embedding

        self.c_attn = CharAttention()
        self.h = nn.ModuleList([Block() for _ in range(config.n_layers)])
        # The heads consume the block output, whose last axis is n_embd wide.
        self.lm_heads = nn.ModuleList([nn.Linear(config.n_embd, config.vocab_size) for _ in range(config.c_block_size)])

    def forward(self, x):
        """Compute logits for every character slot of every word.

        Args:
            x: integer tensor of character ids. ``CharAttention`` unpacks its
               input as 4-D, so ``x`` is expected to be 3-D ``(B, W, c)`` with
               ``c <= config.c_block_size`` and ``W <= config.w_block_size``.

        Returns:
            Tensor of shape ``(B, W, config.c_block_size, config.vocab_size)``.
        """
        # Position indices are created on the input's device so the model works
        # wherever the caller placed it.
        c_pos = torch.arange(x.shape[-1], dtype = torch.long, device = x.device)
        h = self.cte(x) + self.cpe(c_pos)          # (B, W, c, C); cpe broadcasts over B and W
        h = self.c_attn(h)                         # (B, W, C)

        w_pos = torch.arange(h.shape[1], dtype = torch.long, device = x.device)
        h = h + self.wpe(w_pos)
        for block in self.h:
            h = block(h)

        # One head per character slot, stacked on a new axis after W.
        logits = torch.stack([lm_head(h) for lm_head in self.lm_heads], dim = 2)
        return logits
            
        

In [None]:
# Attention cpu version

class CharSelfAttention(nn.Module):
    """Causal multi-head self-attention along the character axis, with dropout.

    The input is a 4-D tensor unpacked as ``(B, W, c, C)`` with
    ``C == config.n_embd``; attention runs over axis ``c``. The output has the
    same shape as the input.
    """

    def __init__(self):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        self.c_attn = nn.Linear(config.n_embd, 3*config.n_embd, bias = False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias = False)
        self.dropout = nn.Dropout(config.dropout_ratio)

    def forward(self, x):
        """Attend causally over axis 2 and project the result.

        Args:
            x: tensor of shape ``(B, W, c, C)`` with ``C == config.n_embd``.

        Returns:
            Tensor of shape ``(B, W, c, C)``.
        """
        B, W, c, C = x.shape
        head_dim = C // config.n_heads

        def split_heads(t):
            # (B, W, c, C) -> (B, W, n_heads, c, head_dim)
            return t.view(B, W, c, config.n_heads, head_dim).transpose(2, 3)

        qkv = self.c_attn(x)
        # Split the local fused projection; there is no `self.qkv` attribute.
        q, k, v = qkv.split(config.n_embd, dim = -1)
        out = F.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v), is_causal = True
        )
        out = out.transpose(2, 3).contiguous().view(B, W, c, C)
        out = self.c_proj(out)
        out = self.dropout(out)
        return out

class MLP(nn.Module):
    """Feed-forward network applied only to the last entry along axis 0.

    NOTE: this redefines ``MLP`` from the earlier cell; running this cell
    replaces that definition.
    """

    def __init__(self):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_hidden, bias = False)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(config.n_hidden, config.n_embd, bias = False)
        self.dropout = nn.Dropout(config.dropout_ratio)

    def forward(self, x):
        """Transform ``x[-1]`` and return ``x`` with that entry replaced.

        Args:
            x: tensor with at least one entry along axis 0 and a last axis of
               size ``config.n_embd``.

        Returns:
            Tensor of the same shape as ``x``; all entries but the last along
            axis 0 are passed through untouched.
        """
        # Keep axis 0 (slice, not index) so the result can be concatenated
        # back onto the untouched prefix.
        word = x[-1:]
        # Chain the layers: each one consumes the previous layer's output.
        word = self.c_fc(word)
        word = self.gelu(word)
        word = self.c_proj(word)
        word = self.dropout(word)
        return torch.cat((x[:-1], word), dim = 0)

class Block(nn.Module):
    """Pre-norm residual block: character self-attention followed by the MLP.

    NOTE: this redefines ``Block`` from the earlier cell; running this cell
    replaces that definition.
    """

    def __init__(self):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CharSelfAttention()
        self.mlp = MLP()
        self.ln_2 = nn.LayerNorm(config.n_embd)

    def forward(self, x):
        """Apply both sub-layers, each on a LayerNorm-ed input with a residual add.

        Uses the same pre-norm residual layout as the ``Block`` defined in the
        earlier cell. Returns a tensor of the same shape as ``x``.
        """
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
        
        
        
        
class GPT(nn.Module):
    """Unfinished sketch of the embedding stage (character + position embeddings).

    NOTE(review): this redefines ``GPT`` from the earlier cell; running this
    cell replaces that definition. The body reads as pseudo-code and is not
    runnable as written; see the notes below.
    """

    def __init__(self):
        # NOTE(review): `super().__init__()` is never called; torch.nn.Module
        # rejects submodule assignment before it has been initialised.
        self.cte = nn.Embedding(config.vocab_size, config.n_embd)    # character token embedding
        self.cpe = nn.Embedding(config.c_block_size, config.n_embd)  # character position embedding
        self.wpe = nn.Embedding(config.w_block_size, config.n_embd)  # word position embedding

    def forward(self, x):
        # Iterates over the first axis of `x`; presumably one entry per word,
        # each holding that word's character ids. TODO confirm with the caller.
        for i, chs in enumerate(x):
            c_emb = self.cte(chs)
            # NOTE(review): the position table is indexed with the same `chs`
            # values that index the token table, not with positions 0..len-1.
            c_pos_emb = self.cpe(chs)
            chs = c_emb + c_pos_emb
            # NOTE(review): placeholder condition: a string literal is compared
            # with an embedding row. The intended "end-of-word token" test is
            # not defined anywhere in view. TODO specify.
            if 'any of the end tokens of a word' == chs[-1]:
                # NOTE(review): `i` is a Python int; nn.Embedding is normally
                # called with an integer tensor. Confirm.
                w_pos_emb = self.wpe(i)
                chs[-1] += w_pos_emb
        # NOTE(review): nothing is returned or stored; the per-word results
        # bound to `chs` are discarded at the end of each iteration.
            
        

In [None]:
# Attention GPU version

class CharSelfAttention(nn.Module):
    """Causal multi-head self-attention along axis 2 of a 4-D input (no dropout).

    The input is unpacked as ``(B, W, T, C)`` with ``C == config.n_embd``; the
    output has the same shape.

    NOTE: this redefines ``CharSelfAttention`` from the earlier cell; running
    this cell replaces that definition.
    """

    def __init__(self):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        self.c_attn = nn.Linear(config.n_embd, 3*config.n_embd, bias = False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias = False)

    def forward(self, x):
        """Attend causally over axis 2 and project the result.

        Args:
            x: tensor of shape ``(B, W, T, C)`` with ``C == config.n_embd``.

        Returns:
            Tensor of shape ``(B, W, T, C)``.
        """
        B, W, T, C = x.shape
        head_dim = C // config.n_heads

        def split_heads(t):
            # (B, W, T, C) -> (B, W, n_heads, T, head_dim)
            return t.view(B, W, T, config.n_heads, head_dim).transpose(2, 3)

        qkv = self.c_attn(x)
        # Split the local fused projection; there is no `self.qkv` attribute.
        q, k, v = qkv.split(config.n_embd, dim = -1)
        out = F.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v), is_causal = True
        )
        out = out.transpose(2, 3).contiguous().view(B, W, T, C)
        out = self.c_proj(out)
        return out

class MLP(nn.Module):
    """Feed-forward network applied only to the last entry along axis 0 (no dropout).

    NOTE: this redefines ``MLP`` from the earlier cells; running this cell
    replaces those definitions.
    """

    def __init__(self):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_hidden, bias = False)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(config.n_hidden, config.n_embd, bias = False)

    def forward(self, x):
        """Transform ``x[-1]`` and return ``x`` with that entry replaced.

        Args:
            x: tensor with at least one entry along axis 0 and a last axis of
               size ``config.n_embd``.

        Returns:
            Tensor of the same shape as ``x``; all entries but the last along
            axis 0 are passed through untouched.
        """
        # Keep axis 0 (slice, not index) so the result can be concatenated
        # back onto the untouched prefix.
        word = x[-1:]
        # Chain the layers: each one consumes the previous layer's output.
        word = self.c_proj(self.gelu(self.c_fc(word)))
        return torch.cat((x[:-1], word), dim = 0)

class Block(nn.Module):
    """Pre-norm residual block: character self-attention followed by the MLP.

    The original definition stopped at an unfinished ``self.`` statement. It is
    completed here with the same layout as the ``Block`` in the preceding cell.
    TODO(review): confirm this is the intended structure.

    NOTE: this redefines ``Block`` from the earlier cells; running this cell
    replaces those definitions.
    """

    def __init__(self):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CharSelfAttention()
        self.mlp = MLP()
        self.ln_2 = nn.LayerNorm(config.n_embd)

    def forward(self, x):
        """Apply both sub-layers, each on a LayerNorm-ed input with a residual add.

        Returns a tensor of the same shape as ``x``.
        """
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
        
        
        
class GPT(nn.Module):
    """Embedding stage: character + character-position embeddings per word,
    with the word-position embedding added to each word's last character.

    Only the embedding step is implemented; no attention blocks or output head
    are applied here.

    NOTE: this redefines ``GPT`` from the earlier cells; running this cell
    replaces those definitions.
    """

    def __init__(self):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        self.cte = nn.Embedding(config.vocab_size, config.n_embd)    # character token embedding
        self.cpe = nn.Embedding(config.c_block_size, config.n_embd)  # character position embedding
        self.wpe = nn.Embedding(config.w_block_size, config.n_embd)  # word position embedding

    def forward(self, x):
        """Embed every word of ``x``.

        Args:
            x: iterable of integer tensors of character ids, one per word
               (assumed 1-D, length <= ``config.c_block_size``; at most
               ``config.w_block_size`` words). TODO confirm with the caller.

        Returns:
            List with one tensor per word, each ``(len(word), config.n_embd)``.
            The embeddings are returned rather than discarded so later stages
            can consume them.
        """
        embedded = []
        for i, chs in enumerate(x):
            # Positions run 0..len-1; the token ids themselves are not valid
            # indices into the position table.
            char_pos = torch.arange(chs.shape[-1], dtype = torch.long, device = chs.device)
            # nn.Embedding needs a tensor index, not a Python int.
            word_pos = torch.tensor(i, dtype = torch.long, device = chs.device)

            char_emb = self.cte(chs)
            char_pos_emb = self.cpe(char_pos)
            word_pos_emb = self.wpe(word_pos)

            emb = char_emb + char_pos_emb
            emb[-1] += word_pos_emb     # tag the word's last character with the word position
            embedded.append(emb)
        return embedded
            
        