In [30]:
import torch
from torch import nn
import math

In [12]:
def generate_compact_dataset(num_samples=10000, generator=None):
    """Build a toy retrieval dataset of 8-token sequences.

    Each row is ``[d0, d1, d2, d3, d4, d5, cmd, query]`` where ``d0..d5`` are
    six distinct tokens drawn from 1..10 and ``cmd`` selects the task
    (each task is picked with probability 0.5 per sample):

    * ``cmd == 11`` (positional): ``query`` is a 1-based index in 1..6 and the
      target is the data token at that index.
    * ``cmd == 12`` (relational): ``query`` is one of ``d0..d4`` and the target
      is the data token immediately after it.

    Args:
        num_samples: number of rows to generate.
        generator: optional ``torch.Generator`` for reproducible sampling.
            ``None`` draws from torch's global RNG, exactly as calling the
            function without this argument.

    Returns:
        ``(inputs, targets)``: int64 tensors of shape ``(num_samples, 8)`` and
        ``(num_samples,)``.
    """
    data_tokens = torch.arange(1, 11)

    all_inputs = []
    all_targets = []

    for _ in range(num_samples):
        perm = torch.randperm(10, generator=generator)
        sample_data = data_tokens[perm[:6]]

        is_relational = torch.rand(1, generator=generator) > 0.5

        if is_relational:
            cmd = torch.tensor([12])
            # Pick a key from the first 5 (so there is a neighbor at +1)
            key_idx = torch.randint(0, 5, (1,), generator=generator).item()
            query = sample_data[key_idx].view(1)
            target = sample_data[key_idx + 1]
        else:
            # POSITIONAL: Input[7] is an Index (1-6); Target is data at that index
            cmd = torch.tensor([11])
            idx_to_pull = torch.randint(0, 6, (1,), generator=generator).item()
            query = torch.tensor([idx_to_pull + 1])
            target = sample_data[idx_to_pull]

        all_inputs.append(torch.cat([sample_data, cmd, query]))
        all_targets.append(target)

    if not all_inputs:
        # torch.stack([]) raises, so return correctly shaped empty tensors instead.
        return (
            torch.empty((0, 8), dtype=torch.long),
            torch.empty((0,), dtype=torch.long),
        )
    return torch.stack(all_inputs), torch.stack(all_targets)

# Seed torch's global RNG so the dataset, the train/test split below and (when
# the cells are run top to bottom) the model initialisation are reproducible.
torch.manual_seed(0)

# Generate the 10,000 samples
inputs, targets = generate_compact_dataset(10000)

print(f"Dataset Shape: {inputs.shape}") # [10000, 8]
print(f"Sample 0 (Input): {inputs[0].tolist()} -> Target: {targets[0].item()}")

Dataset Shape: torch.Size([10000, 8])
Sample 0 (Input): [3, 2, 6, 5, 9, 8, 12, 5] -> Target: 9


In [13]:
# Random 50/50 train/test split over the generated sample indices.
n_samples = len(inputs)
indices = torch.randperm(n_samples)
train_size = int(0.5 * n_samples)

train_idx, test_idx = indices[:train_size], indices[train_size:]

train_inputs, train_targets = inputs[train_idx], targets[train_idx]
test_inputs, test_targets = inputs[test_idx], targets[test_idx]

In [14]:
class embeddings(nn.Module):
    """Learned token-id -> vector lookup table.

    Args:
        d: embedding width.
        vocab_size: number of distinct token ids; valid ids are
            ``0 .. vocab_size - 1``.
    """

    def __init__(self, d: int, vocab_size: int = 13):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d)

    def forward(self, vocab: torch.Tensor) -> torch.Tensor:
        """Embed an integer tensor of token ids.

        ``vocab`` must be an integer (e.g. int64) tensor, not a Python list.
        The output has shape ``vocab.shape + (d,)``.
        """
        return self.token_emb(vocab)

In [34]:
class relative_matrix(nn.Module):
    """Per-position sum of pairwise differences plus a fixed relative-position term.

    For an input ``x`` of shape (batch, seq, e_dim), ``x`` is first projected
    with a linear layer to ``h``; then, for every position ``i``::

        out[:, i] = sum_{j != i} (h[:, i] - h[:, j] + tm[i, j]) + bias

    where ``tm`` is the min-max normalised table produced by :meth:`tan_pos`.
    The output has the same shape as ``x``. Requires ``seq >= 2``.
    """

    def __init__(self, e_dim: int, **args):
        super().__init__()
        self.e_dim = e_dim
        # **args is accepted but ignored (kept so existing call sites still work).
        self.linear = nn.Linear(e_dim, e_dim)
        self.bias = nn.Parameter(torch.zeros(e_dim))

    @staticmethod
    def find_diff(V):
        """Return ``V[:, i] - V[:, j]`` for every ordered pair ``i != j``.

        ``V`` is (batch, s, d); the result is (batch, s * (s - 1), d), with pairs
        in row-major (i, j) order, diagonal removed.
        """
        _, s, _ = V.shape
        diff_4d = V.unsqueeze(2) - V.unsqueeze(1)  # [b, i, j] = V[b, i] - V[b, j]
        mask = ~torch.eye(s, dtype=torch.bool, device=V.device)
        return diff_4d[:, mask, :]

    @staticmethod
    def tan_pos(seq_len: int = 8):
        """Relative-position value for every ordered pair ``(i, j)``, ``i != j``.

        Pairs are in the same row-major order as :meth:`find_diff`, so the result
        has ``seq_len * (seq_len - 1)`` entries. With offset ``m = j - i``:

        * ``m > 0``: ``1 / tan(tv_c[m])``
        * ``m < 0``: ``tan(tv_t[m])`` using Python-style negative indexing,
          i.e. ``tv_t[seq_len + m]``.

        The result is a float32 tensor on the CPU.

        NOTE(review): with seq_len=8, tv_c[6] is about -0.0044, so offset +6 maps
        to roughly -227 while every other entry lies within about [-4.5, 4.5].
        After ``normalise`` that single outlier pins -1 and squeezes all other
        offsets close to +1 -- confirm this is intended.
        """
        margin = 2.2e-1
        tv_c = torch.linspace(-torch.pi / 2 + margin, margin, seq_len)
        tv_t = torch.linspace(margin, torch.pi / 2 - margin, seq_len)
        i, j = torch.where(~torch.eye(seq_len, dtype=torch.bool))
        mark = j - i
        # Wrap negative offsets exactly like Python's negative indexing.
        idx = mark.remainder(seq_len)
        return torch.where(mark > 0, 1 / torch.tan(tv_c[idx]), torch.tan(tv_t[idx]))

    @staticmethod
    def normalise(a):
        """Min-max rescale ``a`` to [-1, 1]; 1e-8 guards a constant input.

        Requires a non-empty tensor (``torch.max`` / ``torch.min`` reduce all
        elements).
        """
        ama = torch.max(a)
        ami = torch.min(a)
        return 2 * (a - ami) / (ama - ami + 1e-8) - 1

    def forward(self, x):
        x = self.linear(x)
        b, s, d = x.shape
        dm = relative_matrix.find_diff(x)  # (b, s*(s-1), d)
        tm = relative_matrix.normalise(relative_matrix.tan_pos(s))
        # tan_pos is built on the CPU; match x's device/dtype so non-CPU inputs work.
        tm = tm.to(device=x.device, dtype=x.dtype)
        mid = dm + tm.unsqueeze(1)  # broadcast the per-pair scalar over e_dim
        # Regroup pairs by source position i, then sum over its s-1 partners j.
        out = mid.view(b, s, s - 1, d).sum(dim=2)
        return out + self.bias

In [31]:
class attention_matrix(nn.Module):
    """Scaled dot-product attention weights (values are applied by the caller).

    ``forward(x, y)`` returns ``softmax(x @ y^T / sqrt(e_dim))`` over the last
    axis. ``x`` and ``y`` must agree on their last dimension; ``e_dim`` is only
    used for the scaling factor.
    """

    def __init__(self, e_dim):
        super().__init__()
        self.e_dim = e_dim

    def forward(self, x, y):
        scale = math.sqrt(self.e_dim)
        scores = torch.matmul(x, y.transpose(-2, -1)) / scale
        return scores.softmax(dim=-1)

In [27]:
class transformer(nn.Module):
    """Single attention layer whose keys come from ``relative_matrix``.

    Token ids are embedded, a key representation is built from pairwise
    embedding differences (``relative_matrix``), and the raw embeddings act as
    queries against those keys. The attention weights mix linearly projected
    embeddings (``x_emb @ V``), and an MLP maps every position to vocabulary
    logits.

    Args:
        e_dim: embedding / model width.
        vocab_size: number of token ids, used for both the embedding table and
            the output head. The default of 13 makes ids 0..12 valid, which
            covers this notebook's data tokens (1..10) and commands (11, 12).
    """

    def __init__(self, e_dim: int, vocab_size: int = 13):
        super().__init__()
        self.e_dim = e_dim
        self.vocab_size = vocab_size
        self.embeddings = embeddings(e_dim, vocab_size)
        self.relative = relative_matrix(self.e_dim)
        self.attention = attention_matrix(self.e_dim)
        # randn is overwritten by xavier init; kept so RNG consumption (and thus
        # seeded initialisation of later layers) is unchanged.
        self.V = nn.Parameter(torch.randn(e_dim, e_dim))
        nn.init.xavier_uniform_(self.V)
        self.linear = nn.Sequential(
            nn.Linear(e_dim, e_dim//2),
            nn.GELU(),
            nn.Linear(e_dim//2, e_dim),
            nn.GELU(),
            nn.Linear(e_dim, vocab_size)
        )

    def forward(self, x):
        """Run the model on a batch of token ids.

        Assumes ``x`` is an integer tensor of shape (batch, seq) -- TODO confirm
        for other callers.

        Returns:
            ``(attn, logits)``: the softmax attention weights between positions
            and the per-position logits over ``vocab_size`` classes.
        """
        x_emb = self.embeddings(x)
        keys = self.relative(x_emb)
        value = x_emb @ self.V
        attn = self.attention(x_emb, keys)
        out = attn @ value
        return attn, self.linear(out)

In [35]:
# Training configuration: full-batch AdamW steps on a 64-wide model.
epochs = 20_000

model = transformer(e_dim=64)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

criterion = nn.CrossEntropyLoss()

In [36]:
# Full-batch training. Every 100 steps the model is evaluated on the held-out
# split and a progress line is printed and appended to the log file; every
# 500 steps a snapshot of metrics and attention maps is kept in `values`.
values = []
for i in range(epochs):
    model.train()
    optimizer.zero_grad()

    att, logits = model(train_inputs)
    # Only the final position (the query slot) is scored against the target.
    predicted = logits[:, -1, :]
    loss = criterion(predicted, train_targets)
    train_acc = (predicted.argmax(-1) == train_targets).float().mean()

    loss.backward()
    optimizer.step()

    if i % 100 == 0:
        model.eval()

        with torch.no_grad():
            att_test, test_logits = model(test_inputs)
            test_predicted = test_logits[:, -1, :]

            test_loss = criterion(test_predicted, test_targets)
            test_acc = (test_predicted.argmax(-1) == test_targets).float().mean()

        o = f"| i: {i} | Train Accuracy: {train_acc} | Train Loss: {loss} | Test Accuracy: {test_acc} | Test Loss: {test_loss} |"
        print(o)
        if i % 500 == 0:
            # Keep detached CPU copies / Python floats: `att` and `loss` still
            # carry autograd history (grad_fn) and may live on an accelerator,
            # so storing them directly would hold that state for the whole run.
            values.append({
                "train_att": att.detach().cpu(),
                "test_att": att_test.detach().cpu(),
                "train_loss": loss.item(),
                "train_acc": train_acc.item(),
                "test_loss": test_loss.item(),
                "test_acc": test_acc.item(),
            })
        with open('epoch_results_logging.log', 'a') as f:
            f.write(f"{o}\n")

        


| i: 0 | Train Accuracy: 0.09279999881982803 | Train Loss: 2.570821762084961 | Test Accuracy: 0.09619999676942825 | Test Loss: 2.5642614364624023 |
| i: 100 | Train Accuracy: 0.2671999931335449 | Train Loss: 2.089670181274414 | Test Accuracy: 0.2386000007390976 | Test Loss: 2.1455256938934326 |
| i: 200 | Train Accuracy: 0.29760000109672546 | Train Loss: 1.9460054636001587 | Test Accuracy: 0.25940001010894775 | Test Loss: 2.0316286087036133 |
| i: 300 | Train Accuracy: 0.3089999854564667 | Train Loss: 1.8736159801483154 | Test Accuracy: 0.2624000012874603 | Test Loss: 2.010807514190674 |
| i: 400 | Train Accuracy: 0.3253999948501587 | Train Loss: 1.8219720125198364 | Test Accuracy: 0.25279998779296875 | Test Loss: 2.017420530319214 |
| i: 500 | Train Accuracy: 0.3312000036239624 | Train Loss: 1.7833935022354126 | Test Accuracy: 0.2492000013589859 | Test Loss: 2.034996747970581 |
| i: 600 | Train Accuracy: 0.33660000562667847 | Train Loss: 1.7593823671340942 | Test Accuracy: 0.245800003

In [None]:
import json
def tensor_handler(obj):
    """``json.dump`` ``default=`` hook that serialises tensors as nested lists.

    Tensors are detached and copied to the CPU first, so tensors that require
    grad or live on an accelerator are handled. Any other type raises
    ``TypeError`` as the ``default`` hook contract requires, instead of
    returning ``None`` (which json would silently write as ``null``).
    """
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

# Persist the periodic training snapshots; tensors are converted by tensor_handler.
with open("results_1.json", "w") as results_file:
    json.dump(obj=values, fp=results_file, indent=4, default=tensor_handler)

In [None]:
# Save only the learned parameters (state_dict), not the pickled module object.
torch.save(obj=model.state_dict(), f='model_weights.pth')