In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
%matplotlib inline

In [None]:
# Hyperparameters
# Model size. Per-head widths are in_embed_dim // num_heads (encoder) and
# out_embed_dim // num_heads (decoder).
batch_size = 64        # sequences per training step
in_embed_dim = 100     # letter (encoder) embedding width
out_embed_dim = 200    # phoneme-symbol (decoder) embedding width
num_heads = 4          # attention heads per multi-head layer
num_layers = 3         # number of stacked encoder blocks (see Encoder)

# Optimisation schedule.
learning_rate = 3e-4
max_iters = 10_000
eval_interval = 1_000  # print the training loss every this many steps
eval_iters = 100       # NOTE(review): not read by any visible cell

In [None]:
# Seed torch's global RNG so weight init and batch sampling are repeatable.
torch.manual_seed(19)

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Read the raw dictionary as bytes, one entry per line.
with open('cmudict-0.7b', 'rb') as fh:
    data = fh.readlines()

# Keep everything except the first 126 and the last 5 lines.
# NOTE(review): presumably the header comments plus punctuation entries -- confirm
# against the file version in use.
data = data[126:-5]

# Decode each remaining line, skipping index 35606 of the sliced list.
# NOTE(review): the reason for skipping it is not visible here (possibly a line
# that does not decode with the default codec) -- confirm.
filtered_data = [line.decode() for idx, line in enumerate(data) if idx != 35606]

# One output symbol per line of the symbols file.
with open('cmudict-0.7b.symbols') as fh:
    symbols = fh.read().splitlines()

In [None]:
# Each entry is "WORD SYM SYM ...": the first field is the spelling, the rest are
# the symbols later looked up in out_stoi. Every pronunciation is wrapped in
# '#' boundary markers (start and end).
spellings, pronounciations = [], []
for line in filtered_data:
    fields = line.strip("\r\n").split()
    spellings.append(fields[0])
    pronounciations.append(['#', *fields[1:], '#'])

In [None]:
# Remove words with numbers in them (e.g. "3D"), including their duplicate
# entries that carry a "(N)" suffix (e.g. "WORD2(1)").
# Both lists are rebuilt in a single pass instead of popping inside
# `enumerate()`. Popping while iterating skips the element right after every
# removal, so back-to-back numeric words would survive. The "(N)" suffix is
# ignored when looking for digits because it is not part of the word itself.
keep_idx = [
    idx for idx, s in enumerate(spellings)
    if not any(ch in '0123456789' for ch in s.split('(')[0])
]
spellings = [spellings[idx] for idx in keep_idx]
pronounciations = [pronounciations[idx] for idx in keep_idx]

In [None]:
# Don't track copies separately: duplicate entries carry a "(N)" suffix
# (e.g. "WORD(1)"); map them back to the bare spelling.
# Only a trailing "(digits)" marker is removed. str.strip() treats its argument
# as a *set* of characters, so strip("(1234567890)") would also remove real
# leading/trailing digits and parentheses ("3D" -> "D", "A42128(1)" -> "A").
def _strip_copy_marker(word):
    """Return `word` without a trailing "(N)" copy marker, if it has one."""
    base, paren, suffix = word.rpartition('(')
    if paren and suffix.endswith(')') and suffix[:-1].isdigit():
        return base
    return word


spellings = [_strip_copy_marker(s) for s in spellings]

In [None]:
# Distribution of spelling lengths; compare with MAX_LENGTH used for padding.
fig, ax = plt.subplots()
ax.hist([len(s) for s in spellings])
ax.set(title='Spelling lengths', xlabel='Characters per spelling', ylabel='Entries');

In [None]:
# Distribution of pronunciation lengths. Each one includes the two '#'
# boundary markers added while parsing.
fig, ax = plt.subplots()
ax.hist([len(p) for p in pronounciations])
ax.set(title='Pronunciation lengths',
       xlabel="Symbols per pronunciation (incl. both '#' markers)",
       ylabel='Entries');

In [None]:
# Input vocabulary: every character seen in a spelling gets an id from 1
# upward; id 0 is reserved for '#', which is also the padding value.
in_tokens = sorted(set(''.join(spellings)))
num_tokens = len(in_tokens) + 1
in_stoi = dict(zip(in_tokens, range(1, num_tokens)))
in_stoi['#'] = 0
in_itos = {idx: tok for tok, idx in in_stoi.items()}

# Output vocabulary: symbols from the symbols file, same scheme ('#' -> 0).
num_symbols = len(symbols) + 1
out_stoi = dict(zip(symbols, range(1, num_symbols)))
out_stoi['#'] = 0
out_itos = {idx: sym for sym, idx in out_stoi.items()}

In [None]:
MAX_LENGTH = 20


def _encode_padded(tokens, vocab, length):
    """Map `tokens[:length]` through `vocab`, then right-pad with id 0 up to `length`."""
    ids = [vocab[tok] for tok in tokens[:length]]
    ids.extend([0] * (length - len(ids)))
    return ids


# Encoder inputs: exactly MAX_LENGTH letter ids per spelling.
spellings_padded = [_encode_padded(s, in_stoi, MAX_LENGTH) for s in spellings]
# One extra slot for pronunciations: get_batch slices [:MAX_LENGTH] as the
# decoder input and [1:] as the shifted target.
pronounciations_padded = [_encode_padded(p, out_stoi, MAX_LENGTH + 1) for p in pronounciations]

In [None]:
spellings_tensor = torch.tensor(spellings_padded, device=device)
pronounciations_tensor = torch.tensor(pronounciations_padded, device=device)

# Hold out 10% for testing. The raw `spellings` / `pronounciations` lists are
# split with the same shuffle, so the test words and their reference symbols
# are available to see generation results at the end.
(
    spellings_train, spellings_test,
    pronounciations_train, pronounciations_test,
    _, test_words,
    _, test_pronounce_symbols,
) = train_test_split(
    spellings_tensor, pronounciations_tensor, spellings, pronounciations,
    test_size=0.1, random_state=19,
)

# Carve 15% of the remaining training data out as a validation split.
spellings_train, spellings_val, pronounciations_train, pronounciations_val = train_test_split(
    spellings_train, pronounciations_train, test_size=0.15, random_state=19,
)

In [None]:
def get_batch(mode='train'):
    """Sample a random mini-batch (rows drawn with replacement) from one split.

    Args:
        mode: 'train', 'val' or 'test'.

    Returns:
        (x_enc, x_dec, y): the sampled encoder rows, the first MAX_LENGTH
        decoder ids of the matching pronunciation rows, and the same rows
        shifted left by one position (the next-symbol targets).

    Raises:
        ValueError: for any other mode.
    """
    if mode not in ('train', 'val', 'test'):
        raise ValueError("Invalid Mode")
    enc_data, dec_data = {
        'train': (spellings_train, pronounciations_train),
        'val': (spellings_val, pronounciations_val),
        'test': (spellings_test, pronounciations_test),
    }[mode]

    rows = torch.randint(enc_data.shape[0], (batch_size,))
    dec_rows = dec_data[rows]
    return enc_data[rows], dec_rows[:, :MAX_LENGTH], dec_rows[:, 1:]

In [None]:
class EncoderHead(nn.Module):
    """A single unmasked scaled-dot-product self-attention head for the encoder.

    Every position attends to every position; no padding mask is applied.
    """

    def __init__(self, head_dim):
        super().__init__()

        # Bias-free projections from the encoder width down to `head_dim`,
        # created (and registered) in q, k, v order.
        self.q, self.k, self.v = (
            nn.Linear(in_embed_dim, head_dim, bias=False) for _ in range(3)
        )

    def forward(self, x):
        queries = self.q(x)
        keys = self.k(x)
        scores = queries @ keys.transpose(-1, -2) / (keys.shape[-1] ** 0.5)
        weights = F.softmax(scores, dim=-1)
        return weights @ self.v(x)

In [None]:
class DecoderHead(nn.Module):
    """A single causally-masked scaled-dot-product self-attention head for the decoder."""

    def __init__(self, head_dim):
        super().__init__()

        # Bias-free projections from the decoder width down to `head_dim`,
        # created (and registered) in q, k, v order.
        self.q, self.k, self.v = (
            nn.Linear(out_embed_dim, head_dim, bias=False) for _ in range(3)
        )
        # Lower-triangular 0/1 matrix: entries that are 0 (future positions)
        # are blocked. Sequences longer than MAX_LENGTH cannot be masked.
        self.register_buffer('mask', torch.tril(torch.ones(MAX_LENGTH, MAX_LENGTH)))

    def forward(self, x):
        queries = self.q(x)
        keys = self.k(x)
        _, seq_len, width = keys.shape
        scores = queries @ keys.transpose(-1, -2) / (width ** 0.5)
        future = self.mask[:seq_len, :seq_len] == 0
        weights = F.softmax(scores.masked_fill(future, float('-inf')), dim=-1)
        return weights @ self.v(x)

In [None]:
class CrossAttentionHead(nn.Module):
    """A single unmasked attention head: decoder queries attend over encoder keys/values."""

    def __init__(self, head_dim):
        super().__init__()

        # Queries come from the decoder width; keys and values from the encoder
        # width. Created (and registered) in q, k, v order.
        self.q = nn.Linear(out_embed_dim, head_dim, bias=False)
        self.k, self.v = (
            nn.Linear(in_embed_dim, head_dim, bias=False) for _ in range(2)
        )

    def forward(self, x_enc, x_dec):
        queries = self.q(x_dec)
        keys = self.k(x_enc)
        scores = queries @ keys.transpose(-1, -2) / (keys.shape[-1] ** 0.5)
        weights = F.softmax(scores, dim=-1)
        return weights @ self.v(x_enc)

In [None]:
class MultiHead(nn.Module):
    """Runs `num_heads` self-attention heads in parallel, concatenates, and projects back.

    `is_encoder` selects EncoderHead + projection to in_embed_dim, otherwise
    DecoderHead (causal) + projection to out_embed_dim.
    """

    def __init__(self, num_heads, head_size, is_encoder):
        super().__init__()

        head_cls, model_dim = (
            (EncoderHead, in_embed_dim) if is_encoder else (DecoderHead, out_embed_dim)
        )
        self.heads = nn.ModuleList([head_cls(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, model_dim)

    def forward(self, x):
        per_head = [head(x) for head in self.heads]
        return self.proj(torch.cat(per_head, dim=-1))

In [None]:
class MultiCrossHead(nn.Module):
    """Runs `num_heads` cross-attention heads in parallel, concatenates, and projects to out_embed_dim."""

    def __init__(self, num_heads, head_size):
        super().__init__()

        heads = [CrossAttentionHead(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(num_heads * head_size, out_embed_dim)

    def forward(self, x_enc, x_dec):
        per_head = [head(x_enc, x_dec) for head in self.heads]
        return self.proj(torch.cat(per_head, dim=-1))

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d -> 4d), ReLU, Linear(4d -> d)."""

    def __init__(self, embed_dim):
        super().__init__()

        hidden_dim = 4 * embed_dim
        stages = [
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        out = x
        for stage in self.layers:
            out = stage(out)
        return out

In [None]:
class EncoderBlock(nn.Module):
    """Post-norm encoder layer: self-attention and feed-forward, each with residual + LayerNorm."""

    def __init__(self):
        super().__init__()

        # Registration order (mha, ln1, ffnn, ln2) fixes init RNG order and state_dict keys.
        self.mha = MultiHead(num_heads, in_embed_dim // num_heads, True)
        self.ln1 = nn.LayerNorm(in_embed_dim)
        self.ffnn = FeedForward(in_embed_dim)
        self.ln2 = nn.LayerNorm(in_embed_dim)

    def forward(self, x):
        attended = self.ln1(x + self.mha(x))
        return self.ln2(attended + self.ffnn(attended))

In [None]:
class DecoderBlock(nn.Module):
    """Post-norm decoder layer: causal self-attention, cross-attention, feed-forward.

    Each sub-layer is wrapped as LayerNorm(x + sublayer(x)).
    """

    def __init__(self):
        super().__init__()

        per_head = out_embed_dim // num_heads
        # Registration order (mha, ln1, ca, ln2, ffnn, ln3) fixes init RNG
        # order and state_dict keys.
        self.mha = MultiHead(num_heads, per_head, False)
        self.ln1 = nn.LayerNorm(out_embed_dim)
        self.ca = MultiCrossHead(num_heads, per_head)
        self.ln2 = nn.LayerNorm(out_embed_dim)
        self.ffnn = FeedForward(out_embed_dim)
        self.ln3 = nn.LayerNorm(out_embed_dim)

    def forward(self, x_enc, x_dec):
        h = self.ln1(x_dec + self.mha(x_dec))
        h = self.ln2(h + self.ca(x_enc, h))
        return self.ln3(h + self.ffnn(h))

In [None]:
class Encoder(nn.Module):
    """Embeds letter ids (token + learned absolute position) and applies `num_layers` EncoderBlocks."""

    def __init__(self):
        super().__init__()

        self.token_embedding = nn.Embedding(num_tokens, in_embed_dim)
        # One learned vector per position: inputs longer than MAX_LENGTH cannot be embedded.
        self.position_embedding = nn.Embedding(MAX_LENGTH, in_embed_dim)
        self.blocks = nn.Sequential(*[EncoderBlock() for _ in range(num_layers)])

    def forward(self, x):
        """Encode a batch of letter-id rows.

        Args:
            x: integer tensor whose dim 1 is the sequence length (presumably
               (batch, MAX_LENGTH) as produced by the padding cell).

        Returns:
            The output of the encoder block stack.
        """
        token_embed = self.token_embedding(x)
        # Build positions on the input's device, not the notebook-global
        # `device`, so the module works wherever it and its inputs live.
        positions = torch.arange(x.shape[1], device=x.device)
        return self.blocks(token_embed + self.position_embedding(positions))

In [None]:
class Decoder(nn.Module):
    """Embeds symbol ids (token + learned position) and runs `num_layers` DecoderBlocks.

    The depth follows the `num_layers` hyperparameter, as the Encoder's does.
    Blocks are registered as `db1`, `db2`, ... so, with num_layers == 3, the
    state_dict keys (and the attributes of a previously pickled model) match
    the original hard-coded db1/db2/db3 layout.
    """

    def __init__(self):
        super().__init__()

        self.token_embedding = nn.Embedding(num_symbols, out_embed_dim)
        # One learned vector per position: inputs longer than MAX_LENGTH cannot be embedded.
        self.position_embedding = nn.Embedding(MAX_LENGTH, out_embed_dim)
        for layer in range(1, num_layers + 1):
            self.add_module(f'db{layer}', DecoderBlock())

    def forward(self, x_enc, x_dec):
        """Decode symbol ids `x_dec` while cross-attending to encoder output `x_enc`.

        Args:
            x_enc: encoder output, passed unchanged to every DecoderBlock.
            x_dec: integer tensor whose dim 1 is the decoder sequence length.

        Returns:
            The output of the last DecoderBlock.
        """
        # Positions on the input's device rather than the notebook-global `device`.
        positions = torch.arange(x_dec.shape[1], device=x_dec.device)
        x_dec = self.token_embedding(x_dec) + self.position_embedding(positions)
        # Children are iterated in registration order, i.e. db1, db2, ...
        for name, block in self.named_children():
            if name.startswith('db'):
                x_dec = block(x_enc, x_dec)
        return x_dec

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder model from spelling ids to per-position output-symbol logits.

    The encoder reads letter ids. The decoder reads teacher-forced symbol ids
    and cross-attends to the encoder output. `head` maps decoder states to
    one logit per output symbol (id 0 is the '#' boundary/padding symbol).
    """

    def __init__(self):
        super().__init__()

        self.encoder = Encoder()
        self.decoder = Decoder()
        self.head = nn.Linear(out_embed_dim, num_symbols)

    def _decode(self, x_enc, pronounciation):
        # Decoder stack + output projection -> (batch, seq, num_symbols) logits.
        return self.head(self.decoder(x_enc, pronounciation))

    def forward(self, spelling, pronounciation, target=None):
        """Run the full model and optionally compute the cross-entropy loss.

        Args:
            spelling: letter-id tensor for the encoder.
            pronounciation: decoder input symbol ids, (B, T).
            target: optional next-symbol ids, reshaped to B*T.

        Returns:
            (logits, loss). Without a target, logits are (B, T, C) and loss
            is None. With a target, logits are flattened to (B*T, C) and loss
            is the mean cross-entropy over all positions (padding included).
        """
        x_enc = self.encoder(spelling)
        logits = self._decode(x_enc, pronounciation)

        if target is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, target.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, spelling, max_len=None):
        """Sample a pronunciation for one spelling, symbol by symbol.

        Args:
            spelling: (1, T_enc) letter-id tensor.
            max_len: maximum number of symbols to sample. Defaults to the size
                of the decoder's position table (the longest decoder input it
                can embed), so decoding can never index past that table.

        Returns:
            (1, L) int64 tensor that starts with the 0 start token and ends
            with a 0 terminator. If no terminator was sampled within `max_len`
            steps, one is appended so callers can always strip both ends.
        """
        if max_len is None:
            max_len = self.decoder.position_embedding.num_embeddings
        # The encoder output does not depend on the decoder input: compute once.
        x_enc = self.encoder(spelling)
        pronounciation = torch.zeros((1, 1), device=spelling.device, dtype=torch.int64)
        for _ in range(max_len):
            logits = self._decode(x_enc, pronounciation)[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_symb = torch.multinomial(probs, num_samples=1)
            pronounciation = torch.cat((pronounciation, next_symb), dim=1)
            if next_symb.item() == 0:
                return pronounciation
        # Length cap reached without a sampled terminator: close the sequence.
        return torch.cat((pronounciation, pronounciation.new_zeros((1, 1))), dim=1)


In [None]:
# Instantiate the model on the selected device and report its size.
model = Transformer().to(device)
num_params = sum(param.numel() for param in model.parameters())
print('Number of Parameters:', num_params)

In [None]:
# AdamW over every model parameter, using the configured learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

In [None]:
# Training loop. The loop variable is `step`, not `iter`: at notebook top level
# it stays bound after the loop, and `iter` would shadow the builtin for every
# later cell.
for step in range(max_iters):

    xb_enc, xb_dec, yb = get_batch('train')

    logits, loss = model(xb_enc, xb_dec, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if step % eval_interval == 0 or step == max_iters - 1:
        # Loss of the current mini-batch only (noisy), not an average over batches.
        print(f"step {step}: train loss {loss.item():.4f}")

In [None]:
# Pickle the whole module object (class reference + weights) to model.pt.
torch.save(obj=model, f='model.pt')

In [None]:
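# To reuse the checkpoint saved by torch.save(model, 'model.pt') instead of retraining,
# run the model-class cells and then uncomment the line below. The file holds a fully
# pickled module, so PyTorch >= 2.6 (where weights_only defaults to True) needs
# weights_only=False -- only do this for files you trust.
# model = torch.load('model.pt', weights_only=False)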

In [None]:
@torch.no_grad()
def get_loss(mode, model):
    if mode == 'train':
        enc_data = spellings_train
        dec_data = pronounciations_train
    elif mode == 'val':
        enc_data = spellings_val
        dec_data = pronounciations_val
    elif mode == 'test':
        enc_data = spellings_test
        dec_data = pronounciations_test
    else:
        raise ValueError("Invalid Mode")

    x_enc = enc_data
    x_dec = dec_data[:, :MAX_LENGTH]
    y = dec_data[:, 1:]
    _, loss = model(x_enc, x_dec, y)
    return loss

In [None]:
get_loss(mode='val', model=model)

In [None]:
get_loss(mode='test', model=model)

In [None]:
def get_pronounciation(word, net=None):
    """Predict a space-separated pronunciation for `word`.

    Args:
        word: spelling. It is upper-cased, then truncated or right-padded with
            '#' to exactly MAX_LENGTH characters. Characters missing from
            `in_stoi` raise KeyError.
        net: object with a `generate(spelling_tensor)` method; defaults to the
            notebook-global `model`.

    Returns:
        The generated symbols joined by spaces, without the leading '#' start
        token and without the trailing '#' terminator when one is present.
    """
    padded = word.upper()[:MAX_LENGTH].ljust(MAX_LENGTH, '#')
    word_tensor = torch.tensor([[in_stoi[c] for c in padded]], device=device)
    generator = model if net is None else net
    pronounciation = generator.generate(word_tensor)

    decoded = [out_itos[i] for i in pronounciation[0].tolist()]
    body = decoded[1:]  # drop the start token
    # Only drop the last symbol if it really is the '#' terminator; a
    # length-capped generation may end on a real symbol.
    if body and body[-1] == '#':
        body = body[:-1]
    return ' '.join(body)

In [None]:
def dets(word, pro, pred_pro):
    """Format one markdown table row: | word | reference pronunciation | predicted pronunciation |."""
    return "| {} | {} | {} |".format(word, pro, pred_pro)

In [None]:
# Compare reference and predicted pronunciations for the first 10 test words,
# printed as markdown table rows (word | reference | prediction).
for idx in range(10):
    word = test_words[idx]
    # Drop the '#' boundary markers added when the entries were parsed.
    pro = ' '.join(test_pronounce_symbols[idx][1:-1])
    print(dets(word, pro, get_pronounciation(word)))