In [1]:
import torch
from torch import nn
import math

In [2]:
def generate_compact_dataset(num_samples=10000):
    data_tokens = torch.arange(1, 11) 
    
    all_inputs = []
    all_targets = []

    for _ in range(num_samples):
        perm = torch.randperm(10)
        sample_data = data_tokens[perm[:6]]
        
        is_relational = torch.rand(1) > 0.5
        
        if is_relational:
            cmd = torch.tensor([12])
            # Pick a key from the first 5 (so there is a neighbor at +1)
            key_idx = torch.randint(0, 5, (1,)).item()
            query = sample_data[key_idx].view(1)
            target = sample_data[key_idx + 1]
        else:
            # POSITIONAL: Input[7] is an Index (1-6); Target is data at that index
            cmd = torch.tensor([11])
            idx_to_pull = torch.randint(0, 6, (1,)).item()
            query = torch.tensor([idx_to_pull + 1])
            target = sample_data[idx_to_pull]

        full_input = torch.cat([sample_data, cmd, query])
        
        all_inputs.append(full_input)
        all_targets.append(target)

    return torch.stack(all_inputs), torch.stack(all_targets)

# Generate the 10,000 samples
inputs, targets = generate_compact_dataset(10000)

print(f"Dataset Shape: {inputs.shape}") # [10000, 8]
print(f"Sample 0 (Input): {inputs[0].tolist()} -> Target: {targets[0].item()}")

Dataset Shape: torch.Size([10000, 8])
Sample 0 (Input): [3, 10, 8, 5, 2, 1, 11, 4] -> Target: 5


In [3]:
indices = torch.randperm(len(inputs))

train_size = int(0.5*len(inputs))

train_idx = indices[:train_size]
test_idx = indices[train_size:]

train_inputs, train_targets = inputs[train_idx], targets[train_idx]
test_inputs,  test_targets  = inputs[test_idx],  targets[test_idx]

In [4]:
pe = torch.zeros([8, 64])
for j in range(3):
    for i in range(0, 64, 2):
        div_term = math.exp(-(math.log(10000.0) * i) / 64)
        pe[j, i] = math.sin(j * div_term)
        pe[j, i + 1] = math.cos(j * div_term)

In [5]:
class embeddings(nn.Module):
    def __init__(self,d: int ,  vocab_size:int = 13):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d)

    def forward(self, vocab: list, pe):
        te = self.token_emb(vocab)
        return te + pe.to(te.dtype)

In [6]:
class relative_matrix(nn.Module):
    def __init__(self, e_dim: int, **args):
        super().__init__()
        self.e_dim = e_dim
        #self.heads = head
        self.linear = nn.Linear(e_dim, e_dim)
        self.bias = nn.Parameter(torch.zeros(e_dim))

    @staticmethod
    def find_diff(V):
        _, s, _ = V.shape
        diff_4d = V.unsqueeze(2) - V.unsqueeze(1)
        mask = ~torch.eye(s, dtype=torch.bool, device=V.device)
        out = diff_4d[:, mask, :] 
        return out
        
    
    @staticmethod
    def normalise(a):
        ama = torch.max(a)
        ami = torch.min(a)
        return 2 * (a - ami) / (ama - ami + 1e-8) - 1

    def forward(self, x):
        x = self.linear(x)
        dm = relative_matrix.find_diff(x)
        mid = dm
        #print(mid.shape)
        b = x.shape[0] # 5000
        d = x.shape[2] # 64
        mid_reshaped = mid.view(b, 8, 7, d)
        out = mid_reshaped.sum(dim=2)
        #print(out.shape)
        out = out + self.bias
        return out

In [7]:
class attention_matrix(nn.Module):
    def __init__(self, e_dim):
        super().__init__()
        self.e_dim = e_dim

    def forward(self, x, y):
        return torch.softmax((x @ y.transpose(-2, -1))/math.sqrt(self.e_dim), -1)

In [11]:
class transformer(nn.Module):
    def __init__(self, e_dim: int):
        super().__init__()
        self.e_dim = e_dim
        self.embeddings = embeddings(e_dim, 13)
        self.relative = relative_matrix(self.e_dim)
        self.attention = attention_matrix(self.e_dim)
        self.V = nn.Parameter(torch.randn(e_dim, e_dim))
        nn.init.xavier_uniform_(self.V)
        self.linear = nn.Sequential(
            nn.Linear(e_dim, e_dim//2),
            nn.GELU(),
            nn.Linear(e_dim//2, e_dim),
            nn.GELU(),
            nn.Linear(e_dim, 13)
        )

    def forward(self, x):
        x_emb = self.embeddings(x, pe)
        a = self.relative(x_emb)
        value = x_emb @ self.V
        b = self.attention(x_emb, a)
        out = b @ value
        return b, self.linear(out)

In [12]:
epochs = 20000

model = transformer(64)
optimizer = torch.optim.AdamW(model.parameters(), lr = 0.001)

criterion = nn.CrossEntropyLoss()

In [13]:
values = []
for i in range(epochs):
    model.train()
    optimizer.zero_grad()

    att, logits = model(train_inputs)
    predicted = logits[:, -1, :]
    loss = criterion(predicted, train_targets)
    train_acc = (predicted.argmax(-1) == train_targets).float().mean()

    loss.backward()
    optimizer.step()

    if i%100==0:
        model.eval()

        with torch.no_grad():
            att_test, test_logits = model(test_inputs)
            test_predicted = test_logits[:, -1, :]

            test_loss = criterion(test_predicted, test_targets)
            test_acc = (test_predicted.argmax(-1) == test_targets).float().mean()

        temp = {
            "train_att": att,
            "test_att": att_test,
            "train_loss": loss,
            "train_acc": train_acc,
            "test_loss": test_loss,
            "test_acc": test_acc
        }
        o = f"| i: {i} | Train Accuracy: {train_acc} | Train Loss: {loss} | Test Accuracy: {test_acc} | Test Loss: {test_loss} |"
        print(o)
        if i%500==0:
            values.append(temp)
        with open('epoch_results_logging.log', 'a') as f:
            f.write(f"{o}\n")

        


| i: 0 | Train Accuracy: 0.09600000083446503 | Train Loss: 2.5641283988952637 | Test Accuracy: 0.10379999876022339 | Test Loss: 2.5561044216156006 |
| i: 100 | Train Accuracy: 0.39800000190734863 | Train Loss: 1.881314992904663 | Test Accuracy: 0.38940000534057617 | Test Loss: 1.914332389831543 |
| i: 200 | Train Accuracy: 0.43059998750686646 | Train Loss: 1.686747431755066 | Test Accuracy: 0.3986000120639801 | Test Loss: 1.7683637142181396 |
| i: 300 | Train Accuracy: 0.45559999346733093 | Train Loss: 1.55287766456604 | Test Accuracy: 0.41839998960494995 | Test Loss: 1.6753251552581787 |
| i: 400 | Train Accuracy: 0.4650000035762787 | Train Loss: 1.4736127853393555 | Test Accuracy: 0.4205999970436096 | Test Loss: 1.6403775215148926 |
| i: 500 | Train Accuracy: 0.48660001158714294 | Train Loss: 1.4065848588943481 | Test Accuracy: 0.41440001130104065 | Test Loss: 1.6631823778152466 |
| i: 600 | Train Accuracy: 0.49380001425743103 | Train Loss: 1.3742568492889404 | Test Accuracy: 0.42419

In [14]:
import json
def tensor_handler(obj):
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()

with open("results_1.json", 'w') as f:
    json.dump(values, f, indent=4, default=tensor_handler)

In [15]:
torch.save(model.state_dict(), 'model_weights.pth')