<a href="https://colab.research.google.com/github/unfortunate-code/Neural-Network-For-Taking-Derivatives/blob/main/Derivatives_FCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive so later cells can read and write files there.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
%cd drive/MyDrive

/content/drive/MyDrive


In [None]:
import itertools
import random
import re
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [None]:
def simplify_equation(line):
  """Split one `d(<expr>)/d<var>=<derivative>` line into its three parts.

  Args:
    line: A raw dataset line; surrounding whitespace is ignored.

  Returns:
    An `(expression, variable, derivative)` tuple of strings.

  Raises:
    ValueError: If the line does not contain exactly one '=' or its left-hand
      side is not of the form `d(...)/d...`.
  """
  parts = line.strip().split('=')
  if len(parts) != 2:
    raise ValueError('expected exactly one "=" in line: {!r}'.format(line))
  x, y = parts
  # Raw string: `\(` and `\)` are regex escapes, not Python string escapes.
  matches = re.match(r'd\((.*)\)/d(.*)', x)
  if matches is None:
    raise ValueError('left-hand side is not of the form d(...)/d...: {!r}'.format(line))
  return matches.group(1), matches.group(2), y

In [None]:
# Parse the training file line by line into (expression, variable, derivative) tuples.
with open('derivatives-takehome/train.txt', 'r') as f:
  data = []
  for line in f:
    data.append(simplify_equation(line))

In [None]:
def tokenize_equations(eqn, var, var_map):
  """Tokenize an expression, replacing variable names with canonical aliases.

  Runs of ASCII letters are grouped: a single letter is treated as a variable
  and replaced by its `varN` alias, a longer run (e.g. `sin`) is kept verbatim.
  Every other character becomes a token on its own.

  Args:
    eqn: Expression string to tokenize.
    var: Name of the differentiation variable; it is registered as `var0`
      when `var_map` is empty.
    var_map: Dict from variable name to `varN` alias. It is updated in place
      with every variable first seen in `eqn`, so passing the map returned for
      one expression keeps the aliases consistent for the next one.

  Returns:
    `(tokens, var_map)`: the token list and the same, updated alias map.
  """
  if not var_map:
    var_map[var] = 'var0'
  # Next free alias number = 1 + highest number in use. The whole numeric
  # suffix is parsed (not just its last character) so this stays correct once
  # aliases reach `var10` and beyond.
  var_index = max(int(alias[len('var'):]) for alias in var_map.values()) + 1

  tokens = []
  # Group 1 captures a run of letters; the alternative matches any single
  # non-letter character. Scanning with a regex also emits a letter run that
  # ends the string, which a flush-on-next-character loop would drop.
  for match in re.finditer(r'([A-Za-z]+)|[^A-Za-z]', eqn):
    word = match.group(1)
    if word is None:
      tokens.append(match.group(0))
    elif len(word) > 1:
      tokens.append(word)
    else:
      if word not in var_map:
        var_map[word] = 'var' + str(var_index)
        var_index += 1
      tokens.append(var_map[word])
  return tokens, var_map

In [None]:
# Tokenize each expression/derivative pair with one shared alias map, and keep
# the inverted map (alias -> original variable name) next to the tokens.
tokenized_data = []
for x, v, y in data:
  x_tokens, var_map = tokenize_equations(x, v, {})
  y_tokens, var_map = tokenize_equations(y, v, var_map)
  tokenized_data.append(
      (x_tokens, y_tokens, dict(zip(var_map.values(), var_map.keys())))
  )

In [None]:
# Distinct tokens seen on the expression side and on the derivative side.
input_vocabulary = {token for example in tokenized_data for token in example[0]}
output_vocabulary = {token for example in tokenized_data for token in example[1]}

In [None]:
# Every original variable name that was replaced by an alias in any example.
variables = {name for example in tokenized_data for name in example[2].values()}

In [None]:
# Contents first, then sizes; one value per line.
print(variables, input_vocabulary, output_vocabulary, sep='\n')
print(len(variables), len(input_vocabulary), len(output_vocabulary), sep='\n')

{'x', 'a', 'm', 'k', 'u', 'y', 'n', 'c', 's', 'z', 'e', 'i', 'w', 'p', 't', 'b', 'r', 'v', 'o'}
{'cos', 'var0', ')', '1', '6', 'sin', '^', '4', '(', '+', '-', 'exp', '2', '9', '*', '3', '8', '0', '7', '5', 'var1'}
{'cos', 'var0', ')', '1', '6', 'sin', '^', '4', '(', '+', '-', 'exp', '2', '9', '*', '3', '8', '0', '7', '5', 'var1'}
19
21
21


In [None]:
pad_token = '<pad>'
input_vocabulary.add(pad_token)
output_vocabulary.add(pad_token)

# Token ids are assigned in sorted order. Iterating the sets directly yields an
# order that can change between interpreter sessions (string hashing is
# randomized), which would silently re-map the ids a saved model was trained on.
input_index_to_token = dict(enumerate(sorted(input_vocabulary)))
input_token_to_index = {token: index for index, token in input_index_to_token.items()}
output_index_to_token = dict(enumerate(sorted(output_vocabulary)))
output_token_to_index = {token: index for index, token in output_index_to_token.items()}

In [None]:
# 90% / 5% / 5% train / validation / test split over the tokenized examples.
# - Sized from the data rather than a hard-coded example count.
# - Drawn from a dedicated, seeded generator so the split is reproducible.
# - Sampled from a sorted list: `random.sample` rejects sets on Python 3.11+.
split_rng = random.Random(0)
all_indices = set(range(len(tokenized_data)))
train_indices = set(split_rng.sample(sorted(all_indices), 9 * len(all_indices) // 10))
all_indices -= train_indices
val_indices = set(split_rng.sample(sorted(all_indices), len(all_indices) // 2))
all_indices -= val_indices
# Whatever is left after removing train and validation is the test set.
test_indices = all_indices

In [None]:
# Split sizes in the order train, test, validation; one per line.
print(len(train_indices), len(test_indices), len(val_indices), sep='\n')

900000
50000
50000


In [None]:
# Materialise the examples selected by each index set.
train_data, test_data, val_data = (
    [tokenized_data[i] for i in subset]
    for subset in (train_indices, test_indices, val_indices)
)

In [None]:
class DerivativesDataset(Dataset):
  """Dataset of tokenized `(x_tokens, y_tokens, alias_map)` examples.

  Each item is returned as two index tensors, right-padded with the pad token
  to a common length, plus the example's alias map unchanged.

  Args:
    data: Sequence of `(x_tokens, y_tokens, alias_map)` tuples.
    max_length: Length every token sequence is padded to. A sequence that is
      already longer is returned as-is (neither padded nor truncated).
    input_vocab: Optional token -> index dict for the expression side;
      defaults to the notebook-level `input_token_to_index`.
    output_vocab: Optional token -> index dict for the derivative side;
      defaults to the notebook-level `output_token_to_index`.
    pad: Optional padding token; defaults to the notebook-level `pad_token`.
  """

  def __init__(self, data, max_length=30, input_vocab=None, output_vocab=None, pad=None):
    self.data = data
    self.max_length = max_length
    self.input_vocab = input_vocab
    self.output_vocab = output_vocab
    self.pad = pad

  def __len__(self):
    return len(self.data)

  def _encode(self, tokens, vocab, pad):
    """Map tokens to indices and right-pad the result to `max_length`."""
    # The pad count comes from this sequence's own token list. (Sizing the
    # target's padding from the already-padded input left targets unpadded.)
    indices = [vocab[token] for token in tokens]
    indices += (self.max_length - len(tokens)) * [vocab[pad]]
    return torch.tensor(indices)

  def __getitem__(self, index):
    x_tokens, y_tokens, alias_map = self.data[index]
    # Notebook-level defaults are looked up per call, so they only need to
    # exist by the time items are actually fetched.
    pad = pad_token if self.pad is None else self.pad
    in_vocab = input_token_to_index if self.input_vocab is None else self.input_vocab
    out_vocab = output_token_to_index if self.output_vocab is None else self.output_vocab
    x = self._encode(x_tokens, in_vocab, pad)
    y = self._encode(y_tokens, out_vocab, pad)
    return x, y, alias_map

In [None]:
# Training split served in shuffled batches of 32 by two worker processes.
train_dataset = DerivativesDataset(train_data)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

In [None]:
class DerivativesModel(nn.Module):
  """Fully connected network from a fixed-length token sequence to per-position logits.

  The embeddings of all tokens are flattened into one vector and passed through
  four linear layers whose widths are x2, x4 and x2 the flattened embedding
  size, followed by `output_size * sequence_length` outputs.

  Args:
    input_size: Number of distinct input token ids (rows of the embedding).
    sequence_length: Number of tokens per input and of predicted positions.
    output_size: Number of output classes per position.
    hidden_size: Embedding dimension per token.
  """

  def __init__(self, input_size, sequence_length, output_size, hidden_size):
    super(DerivativesModel, self).__init__()
    flat_size = hidden_size * sequence_length
    # Kept on the module so `forward` does not depend on a notebook global.
    self.output_size = output_size
    self.embedding_layer = nn.Embedding(input_size, hidden_size)
    self.linear1 = nn.Linear(flat_size, flat_size * 2)
    self.linear2 = nn.Linear(flat_size * 2, flat_size * 4)
    self.linear3 = nn.Linear(flat_size * 4, flat_size * 2)
    # Distinct name for the output layer: reusing `linear3` replaced the
    # 4x -> 2x layer and left `linear2`'s output with no layer of matching size.
    self.linear4 = nn.Linear(flat_size * 2, output_size * sequence_length)
    for layer in (self.linear1, self.linear2, self.linear3, self.linear4):
      torch.nn.init.kaiming_normal_(layer.weight, mode='fan_in')

  def forward(self, x):
    """Compute logits for a batch of token-index sequences.

    Args:
      x: Integer tensor of shape `(batch, sequence_length)`.

    Returns:
      Tensor of shape `(batch * sequence_length, output_size)`.
    """
    embedding = self.embedding_layer(x).view(x.shape[0], -1)
    output = torch.nn.functional.relu(self.linear1(embedding))
    output = torch.nn.functional.relu(self.linear2(output))
    output = torch.nn.functional.relu(self.linear3(output))
    # No activation after the last layer: a ReLU here would clamp the logits
    # handed to the loss to be non-negative.
    output = self.linear4(output)
    return output.view(-1, self.output_size)

In [None]:
hidden_size = 22
# The sequence length (30) must equal the length DerivativesDataset pads to.
model = DerivativesModel(len(input_vocabulary), 30, len(output_vocabulary), hidden_size)
# Use the GPU when one is attached; otherwise stay on the CPU instead of raising.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.train().to(device)

RuntimeError: ignored

In [None]:
# Training hyper-parameters, loss and optimiser (Adam over all model parameters).
epochs = 100
learning_rate = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Batches are moved to whichever device the model's weights live on.
device = next(model.parameters()).device
for epoch in range(epochs):
  print('Epoch', epoch)
  model.train()
  total_loss = []
  # The third batch element (alias maps) is not needed for training. It is not
  # bound to the name `map`, which would shadow the builtin for every later cell.
  for x, y, alias_maps in tqdm(train_dataloader):
    x = x.to(device)
    # Flatten targets to one class index per position, matching the model's
    # (batch * sequence_length, classes) output.
    y = y.to(device).view(-1)
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()
    # Record the batch loss; without this the epoch average below divides by zero.
    total_loss.append(loss.item())
  if total_loss:
    print(sum(total_loss) / len(total_loss))

Epoch 0


100%|██████████| 220/220 [00:32<00:00,  6.72it/s]


3.1036582074718932
Epoch 1


100%|██████████| 220/220 [00:32<00:00,  6.83it/s]


1.621758813974798
Epoch 2


100%|██████████| 220/220 [00:32<00:00,  6.75it/s]


1.2493927411094148
Epoch 3


100%|██████████| 220/220 [00:32<00:00,  6.70it/s]


1.1204184391786174
Epoch 4


100%|██████████| 220/220 [00:32<00:00,  6.71it/s]


1.0539479894783867
Epoch 5


100%|██████████| 220/220 [00:31<00:00,  6.97it/s]


1.0121592087571019
Epoch 6


100%|██████████| 220/220 [00:32<00:00,  6.82it/s]


0.9796748103784217
Epoch 7


100%|██████████| 220/220 [00:32<00:00,  6.86it/s]


0.9513234145719037
Epoch 8


100%|██████████| 220/220 [00:31<00:00,  6.89it/s]


0.9261528781254685
Epoch 9


100%|██████████| 220/220 [00:32<00:00,  6.85it/s]


0.9034015956855509
Epoch 10


100%|██████████| 220/220 [00:31<00:00,  6.90it/s]


0.8817948346296626
Epoch 11


100%|██████████| 220/220 [00:32<00:00,  6.81it/s]


0.8608476265123294
Epoch 12


100%|██████████| 220/220 [00:32<00:00,  6.83it/s]


0.8425291693213606
Epoch 13


100%|██████████| 220/220 [00:32<00:00,  6.86it/s]


0.8249855482004462
Epoch 14


100%|██████████| 220/220 [00:32<00:00,  6.87it/s]


0.8089004338804071
Epoch 15


100%|██████████| 220/220 [00:31<00:00,  6.93it/s]


0.7930092880357137
Epoch 16


100%|██████████| 220/220 [00:32<00:00,  6.77it/s]


0.7781281413094852
Epoch 17


100%|██████████| 220/220 [00:31<00:00,  6.90it/s]


0.7644618015325758
Epoch 18


100%|██████████| 220/220 [00:31<00:00,  6.91it/s]


0.7520968951460134
Epoch 19


100%|██████████| 220/220 [00:32<00:00,  6.87it/s]


0.7397549823702624
Epoch 20


100%|██████████| 220/220 [00:32<00:00,  6.70it/s]


0.7274688455546575
Epoch 21


100%|██████████| 220/220 [00:31<00:00,  6.98it/s]


0.7174730843448791
Epoch 22


100%|██████████| 220/220 [00:31<00:00,  6.91it/s]


0.7064839947232676
Epoch 23


100%|██████████| 220/220 [00:33<00:00,  6.61it/s]


0.6968718569260091
Epoch 24


100%|██████████| 220/220 [00:32<00:00,  6.85it/s]


0.6870039677652982
Epoch 25


100%|██████████| 220/220 [00:31<00:00,  6.93it/s]


0.6784106323380343
Epoch 26


100%|██████████| 220/220 [00:31<00:00,  6.89it/s]


0.6695045704642931
Epoch 27


100%|██████████| 220/220 [00:31<00:00,  6.99it/s]


0.6650787089470653
Epoch 28


100%|██████████| 220/220 [00:31<00:00,  7.00it/s]


0.6533100470069355
Epoch 29


100%|██████████| 220/220 [00:31<00:00,  6.93it/s]


0.6468722471161195
Epoch 30


100%|██████████| 220/220 [00:33<00:00,  6.57it/s]


0.6407291022885431
Epoch 31


100%|██████████| 220/220 [00:31<00:00,  6.93it/s]


0.6349707789724112
Epoch 32


100%|██████████| 220/220 [00:31<00:00,  6.89it/s]


0.6290663170943698
Epoch 33


100%|██████████| 220/220 [00:31<00:00,  6.92it/s]


0.6236884187511419
Epoch 34


100%|██████████| 220/220 [00:31<00:00,  6.93it/s]


0.6187976831395019
Epoch 35


100%|██████████| 220/220 [00:32<00:00,  6.80it/s]


0.6160859399079596
Epoch 36


100%|██████████| 220/220 [00:31<00:00,  6.91it/s]


0.6072134966311115
Epoch 37


100%|██████████| 220/220 [00:32<00:00,  6.83it/s]


0.6082370617421925
Epoch 38


100%|██████████| 220/220 [00:31<00:00,  6.89it/s]


0.5982796078878304
Epoch 39


100%|██████████| 220/220 [00:31<00:00,  6.89it/s]


0.5948435699101537
Epoch 40


100%|██████████| 220/220 [00:31<00:00,  6.89it/s]


0.5920177758370917
Epoch 41


100%|██████████| 220/220 [00:31<00:00,  6.93it/s]


0.5925185445405703
Epoch 42


100%|██████████| 220/220 [00:32<00:00,  6.84it/s]


0.5832875177905124
Epoch 43


100%|██████████| 220/220 [00:32<00:00,  6.87it/s]


0.5904986146557075
Epoch 44


100%|██████████| 220/220 [00:31<00:00,  6.94it/s]


0.5759138829220006
Epoch 45


100%|██████████| 220/220 [00:32<00:00,  6.74it/s]


0.5729091276004109
Epoch 46


100%|██████████| 220/220 [00:32<00:00,  6.69it/s]


0.5762605244664335
Epoch 47


100%|██████████| 220/220 [00:31<00:00,  6.95it/s]


0.5671889283969718
Epoch 48


100%|██████████| 220/220 [00:33<00:00,  6.51it/s]


0.5641796905747665
Epoch 49


100%|██████████| 220/220 [00:35<00:00,  6.12it/s]


0.572619176617453
Epoch 50


100%|██████████| 220/220 [00:31<00:00,  6.95it/s]


0.5575562215069689
Epoch 51


100%|██████████| 220/220 [00:32<00:00,  6.68it/s]


0.5558616792612899
Epoch 52


100%|██████████| 220/220 [00:31<00:00,  6.91it/s]


0.5681038695682962
Epoch 53


100%|██████████| 220/220 [00:32<00:00,  6.81it/s]


0.5497704567888411
Epoch 54


100%|██████████| 220/220 [00:32<00:00,  6.85it/s]


0.547678952586026
Epoch 55


100%|██████████| 220/220 [00:33<00:00,  6.54it/s]


0.5565217508889374
Epoch 56


100%|██████████| 220/220 [00:31<00:00,  6.96it/s]


0.5422346810710126
Epoch 57


100%|██████████| 220/220 [00:31<00:00,  6.89it/s]


0.540475785379268
Epoch 58


100%|██████████| 220/220 [00:31<00:00,  7.04it/s]


0.5570766386996594
Epoch 59


100%|██████████| 220/220 [00:31<00:00,  6.95it/s]


0.5365218950100042
Epoch 60


100%|██████████| 220/220 [00:31<00:00,  6.94it/s]


0.5335053039554466
Epoch 61


100%|██████████| 220/220 [00:31<00:00,  7.02it/s]


0.5319077965442932
Epoch 62


100%|██████████| 220/220 [00:31<00:00,  7.07it/s]


0.5535540826173039
Epoch 63


100%|██████████| 220/220 [00:31<00:00,  6.94it/s]


0.527729535915999
Epoch 64


100%|██████████| 220/220 [00:31<00:00,  6.97it/s]


0.5262164419907476
Epoch 65


100%|██████████| 220/220 [00:31<00:00,  7.01it/s]


0.525257394926983
Epoch 66


100%|██████████| 220/220 [00:31<00:00,  7.00it/s]


0.5485210125434703
Epoch 67


100%|██████████| 220/220 [00:31<00:00,  6.93it/s]


0.5210686422939542
Epoch 68


100%|██████████| 220/220 [00:31<00:00,  7.01it/s]


0.5189394115931004
Epoch 69


100%|██████████| 220/220 [00:31<00:00,  7.02it/s]


0.5172840538084187
Epoch 70


100%|██████████| 220/220 [00:31<00:00,  7.02it/s]


0.5311871129964348
Epoch 71


100%|██████████| 220/220 [00:31<00:00,  7.05it/s]


0.5137975048624299
Epoch 72


100%|██████████| 220/220 [00:31<00:00,  6.97it/s]


0.5122689681292902
Epoch 73


100%|██████████| 220/220 [00:31<00:00,  6.97it/s]


0.5261656560093886
Epoch 74


100%|██████████| 220/220 [00:31<00:00,  6.97it/s]


0.510649552283388
Epoch 75


100%|██████████| 220/220 [00:31<00:00,  6.90it/s]


0.5071883235847757
Epoch 76


100%|██████████| 220/220 [00:32<00:00,  6.86it/s]


0.5056358437199066
Epoch 77


100%|██████████| 220/220 [00:33<00:00,  6.66it/s]


0.5189875233768215
Epoch 78


100%|██████████| 220/220 [00:31<00:00,  7.04it/s]


0.5260784021874149
Epoch 79


100%|██████████| 220/220 [00:31<00:00,  6.99it/s]


0.500765253839794
Epoch 80


100%|██████████| 220/220 [00:31<00:00,  7.00it/s]


0.49923025763353274
Epoch 81


100%|██████████| 220/220 [00:32<00:00,  6.85it/s]


0.49796800169172467
Epoch 82


100%|██████████| 220/220 [00:31<00:00,  7.05it/s]


0.5199347403506318
Epoch 83


100%|██████████| 220/220 [00:30<00:00,  7.10it/s]


0.5287780083155119
Epoch 84


100%|██████████| 220/220 [00:32<00:00,  6.77it/s]


0.49468018680119946
Epoch 85


100%|██████████| 220/220 [00:31<00:00,  7.03it/s]


0.492874194732717
Epoch 86


100%|██████████| 220/220 [00:33<00:00,  6.62it/s]


0.4917874924152569
Epoch 87


100%|██████████| 220/220 [00:31<00:00,  7.01it/s]


0.5095995657110928
Epoch 88


100%|██████████| 220/220 [00:31<00:00,  6.96it/s]


0.4892800115774916
Epoch 89


100%|██████████| 220/220 [00:31<00:00,  7.05it/s]


0.4875493700721211
Epoch 90


100%|██████████| 220/220 [00:31<00:00,  6.98it/s]


0.48630655420723784
Epoch 91


100%|██████████| 220/220 [00:31<00:00,  7.02it/s]


0.4857621076616949
Epoch 92


100%|██████████| 220/220 [00:31<00:00,  7.05it/s]


0.4839972259541873
Epoch 93


100%|██████████| 220/220 [00:31<00:00,  6.95it/s]


0.5399481891704503
Epoch 94


100%|██████████| 220/220 [00:31<00:00,  6.97it/s]


0.4824001806442656
Epoch 95


100%|██████████| 220/220 [00:31<00:00,  6.95it/s]


0.4795270743679696
Epoch 96


100%|██████████| 220/220 [00:30<00:00,  7.19it/s]


0.4780582267821075
Epoch 97


100%|██████████| 220/220 [00:31<00:00,  7.02it/s]


0.47717958994670423
Epoch 98


100%|██████████| 220/220 [00:31<00:00,  7.05it/s]


0.4760527797518999
Epoch 99


100%|██████████| 220/220 [00:31<00:00,  7.06it/s]

0.4751981512408063





In [None]:
# Persist the trained model object to 'fcnn-model.model' (relative to the current working directory).
torch.save(model, 'fcnn-model.model')

In [None]:
def get_eqn(y):
  """Translate a sequence of output-token indices back into token strings.

  Decoding stops just before the first element equal to the end token's index.
  Elements may be plain ints or objects exposing `.item()`.
  """
  # NOTE(review): `end_token` is not defined in the cells visible here —
  # confirm it exists before calling this.
  decoded = []
  for element in y:
    if element == output_token_to_index[end_token]:
      return decoded
    key = element if isinstance(element, int) else element.item()
    decoded.append(output_index_to_token[key])
  return decoded

In [None]:
def is_match(y, pred, stop_index=None):
  """Return True when `pred` reproduces the reference sequence `y` exactly.

  Both sequences are compared up to (excluding) their first stop index, so
  anything after the terminator is ignored, while a prediction that is shorter
  or longer than the reference content counts as a mismatch.

  Args:
    y: Reference indices (ints or single-element tensors).
    pred: Predicted indices (ints or single-element tensors).
    stop_index: Index that terminates a sequence. Defaults to the output index
      of `end_token`.

  Returns:
    bool: whether the two truncated sequences are identical.
  """
  if stop_index is None:
    # NOTE(review): `end_token` is not defined in the cells visible here —
    # confirm it exists, or pass `stop_index` explicitly.
    stop_index = output_token_to_index[end_token]

  def content(sequence):
    # Plain-int view of `sequence`, cut at the first stop index.
    values = []
    for element in sequence:
      value = int(element)
      if value == stop_index:
        break
      values.append(value)
    return values

  # Comparing whole lists also rejects predictions that carry extra tokens
  # after the reference has ended.
  return content(y) == content(pred)

In [None]:
def evaluate(dataloader, cutoff_length, encoder, decoder):
  """Greedy-decode every example and return the fraction that `is_match` accepts.

  Args:
    dataloader: Iterable of `(x, y, extra)` batches; `x` and `y` are moved to
      the GPU, the third element is ignored.
    cutoff_length: Maximum number of decoding steps per example.
    encoder: Module called as `encoder(x)` and returning `(outputs, (h_n, c_n))`.
    decoder: Module called as `decoder(input, h, c)` and returning
      `(scores, h, c)`.

  Returns:
    `correct / total` over all examples, or 0.0 when the dataloader is empty.
  """
  # NOTE(review): `start_token` and `end_token` are not defined in the cells
  # visible here — confirm they exist before running.
  encoder = encoder.eval()
  decoder = decoder.eval()
  start_index = output_token_to_index[start_token]
  end_index = output_token_to_index[end_token]
  count = correct_count = 0
  with torch.no_grad():
    for x, y, _ in tqdm(dataloader):
      count += len(x)
      x, y = x.cuda(), y.cuda()
      _, (enc_hn, enc_cn) = encoder(x)
      for i in range(len(x)):
        # Slice out example i's state and restore a size-1 axis in its place
        # (assumes axis 1 of the encoder state is the batch axis — TODO confirm).
        dec_h_i = enc_hn[:, i, :].contiguous().unsqueeze(1)
        dec_c_i = enc_cn[:, i, :].contiguous().unsqueeze(1)
        # Built directly on the batch's device instead of on the CPU and copied.
        decoder_input = torch.full((1, 1), start_index, dtype=torch.int32, device=x.device)
        decoded_indices = []
        for _step in range(cutoff_length):
          pred, dec_h_i, dec_c_i = decoder(decoder_input, dec_h_i, dec_c_i)
          topi = pred.topk(1)[1].item()
          if topi == end_index:
            break
          decoded_indices.append(topi)
          # Feed the chosen token back in as the next decoder input.
          decoder_input = torch.full((1, 1), topi, dtype=torch.int32, device=x.device)
        if is_match(y[i], decoded_indices):
          correct_count += 1
  return correct_count / count if count else 0.0

In [None]:
# NOTE(review): `PadSequence`, `encoder` and `decoder` are not defined in the
# cells visible here — confirm they exist before running this cell.
val_dataset = DerivativesDataset(val_data)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=4096,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    collate_fn=PadSequence(),
)
evaluate(val_dataloader, 40, encoder, decoder)

100%|██████████| 13/13 [05:05<00:00, 23.52s/it]


0.01102

In [None]:
# Decode the target (second element) of the first validation example back into tokens.
get_eqn(val_dataset[0][1])

['4', '1', '6', 'exp', '^', '(', '1', '0', 'var1', ')', '*', 'var0', '^', '3']