In [10]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import sys
sys.path.append('../src/')
from tokenizer import Tokenizer
from data_utils import dec2base, base2dec
from models import TransformerEmbedding

In [2]:
class CongruenceDataset(Dataset):
    def __init__(self, max_val, base, skip_nums=None, sample_nums=None, sample_func=None, size=-1):
        self.max_val = max_val
        self.base = base
        
        if skip_nums is not None:
            self.skip_nums = set(skip_nums)
        else:
            self.skip_nums = set([])
            
        if sample_nums is None:
            self.sample_nums = None
        else:
            self.sample_nums = sample_nums
        
        if sample_func is None:
            if sample_nums is None:
                self.sample_func = self.default_sample_func
            else:
                self.sample_func = lambda i: self.sample_nums[i]
        else:
            self.sample_func = sample_func
            
        if size < 0:
            self.size = self.max_val
        else:
            self.size = size
            
    def default_sample_func(self, i):
        while i in self.skip_nums:
            i = np.random.randint(0, self.max_val)
        return i
            
    def __getitem__(self, i):
        return dec2base(self.sample_func(i), base)
    
    def __len__(self):
        if self.sample_nums is not None:
            return len(self.sample_nums)
        else:
            return self.size
    
    

In [19]:
max_val = 2**16
test_size = 5000
base = 10
test_nums = np.random.randint(0, max_val, test_size)

In [4]:
train_dataset = CongruenceDataset(max_val, base, skip_nums = test_nums, size=10000)
test_dataset = CongruenceDataset(max_val, base, sample_nums = test_nums)

In [5]:
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn = lambda x: list(x))
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn = lambda x: list(x))

In [6]:
t = Tokenizer(base)

In [7]:
for batch in train_loader:
    tokenized = torch.tensor(t.encode(batch))
    print(tokenized)
    print(tokenized.size())
    break

tensor([[ 6,  5,  0,  5, 11],
        [ 1,  1,  4,  0, 11],
        [ 1,  9,  6,  6, 11],
        [ 9,  2,  2,  0, 11],
        [ 3,  2,  0,  7, 11],
        [ 4,  9,  7,  9, 11],
        [ 8,  6,  4,  8, 11],
        [ 8,  4,  6,  9, 11],
        [ 6,  9,  1,  7, 11],
        [ 4,  5, 11, 11, 11],
        [ 3,  6,  6,  8,  3],
        [ 6,  1,  2,  5, 11],
        [ 6,  7,  4,  6, 11],
        [ 6,  0,  1, 11, 11],
        [ 6,  1,  2,  7, 11],
        [ 3,  2,  2,  4, 11]])
torch.Size([16, 5])


In [77]:
embed_dim = 128
n_chunks = int(np.ceil(np.log(2**16) / np.log(base))) + 1

In [96]:
class CongruenceFinder(nn.Module):
    def __init__(self, n_tokens, n_chunks, embed_dim, **transformer_args):
        super(CongruenceFinder, self).__init__()
        self.embedding = TransformerEmbedding(n_tokens, embed_dim, 32, .1, False, False)
        encoder_layer = nn.TransformerEncoderLayer(embed_dim, 8, **transformer_args)
        self.transformer = nn.TransformerEncoder(encoder_layer, 4, norm=nn.LayerNorm(embed_dim))
        self.pred_out = nn.Linear(embed_dim, n_chunks*2)
            
    def forward(self, x):
        x = self.embedding(x)
        encoded = self.transformer(x)
        encoded = encoded.transpose(0,1).mean(dim=1)
        pred = self.pred_out(encoded)
        print('pred size: ', pred.size())
        return pred

In [97]:
model = CongruenceFinder(len(t), n_chunks, embed_dim)

In [116]:
def loss_fn(logits):
    probs = torch.sigmoid(logits).detach().data.cpu().numpy()
    scale_vector = np.array([base**i for i in range(n_chunks)])
    
    first_num = probs[:,:n_chunks] @ scale_vector
    second_num = probs[:,n_chunks:] @ scale_vector
    
    first_squared = first_num**2
    second_squared = second_num**2
    
    print(first_squared.shape, second_squared.shape)
    

In [117]:
for batch in train_loader:
    tokenized = torch.tensor(t.encode(batch))
    logits = model(tokenized)
    print(logits.size())
    loss_fn(logits)
    break

pred size:  torch.Size([16, 12])
torch.Size([16, 12])
(16,) (16,)
