In [24]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import torch.nn as nn
import torch.nn.functional as F

In [2]:
with open('/kaggle/input/popular-names/names.txt','r') as file:
    data = file.read().split('\n')

In [3]:
data[:5]

['emma', 'olivia', 'ava', 'isabella', 'sophia']

In [4]:
train_data = data[:int(len(data)*0.8)]
val_data = data[int(len(data)*0.8):int(len(data)*0.9)]
test_data = data[int(len(data)*0.9):int(len(data))]

In [5]:
print(len(train_data) , len(val_data) , len(test_data))

25626 3203 3204


In [6]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.enc = nn.Embedding(27,9)
        self.layer1 = nn.Linear(9*2,9*3)
        self.layer2 = nn.Linear(9*3,9*4)
        self.layer3 = nn.Linear(9*4,27)
        self.probs = nn.Softmax(dim=1)
    def forward(self, x):
        x = self.enc(x)
        x = x.view(1,18)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.probs(x)
        return x

In [7]:
trigram = Model()

In [8]:
trigram.eval()

Model(
  (enc): Embedding(27, 9)
  (layer1): Linear(in_features=18, out_features=27, bias=True)
  (layer2): Linear(in_features=27, out_features=36, bias=True)
  (layer3): Linear(in_features=36, out_features=27, bias=True)
  (probs): Softmax(dim=1)
)

In [9]:
outputs = trigram(torch.tensor([1,2]))


In [10]:
outputs.shape

torch.Size([1, 27])

In [11]:
from tqdm import tqdm

In [12]:
mp = {'a': 1, 'A': 1, 'b': 2, 'B': 2, 'c': 3, 'C': 3, 'd': 4, 'D': 4, 'e': 5, 'E': 5, 'f': 6, 'F': 6, 'g': 7, 'G': 7, 'h': 8, 'H': 8, 'i': 9, 'I': 9, 'j': 10, 'J': 10, 'k': 11, 'K': 11, 'l': 12, 'L': 12, 'm': 13, 'M': 13, 'n': 14, 'N': 14, 'o': 15, 'O': 15, 'p': 16, 'P': 16, 'q': 17, 'Q': 17, 'r': 18, 'R': 18, 's': 19, 'S': 19, 't': 20, 'T': 20, 'u': 21, 'U': 21, 'v': 22, 'V': 22, 'w': 23, 'W': 23, 'x': 24, 'X': 24, 'y': 25, 'Y': 25, 'z': 26, 'Z': 26}


In [13]:
epochs = 10

In [14]:
import torch.optim as optim

In [27]:
optimizer = optim.Adam(trigram.parameters(), lr=10)

In [16]:
criterion = nn.MSELoss()

In [19]:
train_loss = []
for itr in tqdm(range(0,epochs)):
    net_loss = 0
    for word in train_data:
        for i in range(0,len(word)-2):
            optimizer.zero_grad()
            input = torch.tensor([mp[word[i]],mp[word[i+1]]])
            ground = F.one_hot(torch.tensor(mp[word[i+2]]),27).view(1, 27)
            output = trigram(input)
            loss = criterion(output.float(),ground.float())
            loss.backward()
            optimizer.step()
            net_loss+=loss
    print(net_loss)
    train_loss.append(net_loss)

 10%|█         | 1/10 [02:03<18:31, 123.52s/it]

tensor(3353.9563, grad_fn=<AddBackward0>)


 20%|██        | 2/10 [03:57<15:44, 118.08s/it]

tensor(3314.3972, grad_fn=<AddBackward0>)


 30%|███       | 3/10 [05:51<13:33, 116.28s/it]

tensor(3306.8186, grad_fn=<AddBackward0>)


 40%|████      | 4/10 [07:47<11:35, 115.87s/it]

tensor(3302.9299, grad_fn=<AddBackward0>)


 50%|█████     | 5/10 [09:42<09:38, 115.69s/it]

tensor(3300.3359, grad_fn=<AddBackward0>)


 60%|██████    | 6/10 [11:37<07:41, 115.43s/it]

tensor(3298.6614, grad_fn=<AddBackward0>)


 70%|███████   | 7/10 [13:33<05:47, 115.79s/it]

tensor(3297.9810, grad_fn=<AddBackward0>)


 80%|████████  | 8/10 [15:30<03:52, 116.07s/it]

tensor(3297.6467, grad_fn=<AddBackward0>)


 90%|█████████ | 9/10 [17:25<01:55, 115.78s/it]

tensor(3297.3098, grad_fn=<AddBackward0>)


100%|██████████| 10/10 [19:21<00:00, 116.14s/it]

tensor(3297.1846, grad_fn=<AddBackward0>)





In [26]:
print(train_loss)

[tensor(3353.9563, grad_fn=<AddBackward0>), tensor(3314.3972, grad_fn=<AddBackward0>), tensor(3306.8186, grad_fn=<AddBackward0>), tensor(3302.9299, grad_fn=<AddBackward0>), tensor(3300.3359, grad_fn=<AddBackward0>), tensor(3298.6614, grad_fn=<AddBackward0>), tensor(3297.9810, grad_fn=<AddBackward0>), tensor(3297.6467, grad_fn=<AddBackward0>), tensor(3297.3098, grad_fn=<AddBackward0>), tensor(3297.1846, grad_fn=<AddBackward0>)]


In [None]:
plt.plot(train_loss)
plt.show()