In [32]:
import torch
import torch.nn as nn
import torch.optim as optim

class DNC(nn.Module):
    def __init__(self, input_size, hidden_size, memory_size, memory_vector_dim, num_read_heads=1):
        super(DNC, self).__init__()
        self.hidden_size = hidden_size
        self.memory_size = memory_size
        self.memory_vector_dim = memory_vector_dim
        self.num_read_heads = num_read_heads

        self.controller = nn.LSTMCell(input_size + num_read_heads * memory_vector_dim, hidden_size)
        self.memory = nn.Parameter(torch.zeros(memory_size, memory_vector_dim))
        self.read_heads = nn.ModuleList([nn.Linear(hidden_size, memory_size) for _ in range(num_read_heads)])
        self.write_head = nn.Linear(hidden_size, memory_size)
        self.erase_head = nn.Linear(hidden_size, memory_vector_dim)
        self.add_head = nn.Linear(hidden_size, memory_vector_dim)
        self.output = nn.Linear(hidden_size + num_read_heads * memory_vector_dim, input_size)

    def forward(self, x, hidden=None):
        batch_size = x.size(0)
        seq_length = x.size(1)

        if hidden is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c = torch.zeros(batch_size, self.hidden_size, device=x.device)
            hidden = (h, c)

        outputs = []
        read_vectors = [torch.zeros(batch_size, self.memory_vector_dim, device=x.device) for _ in range(self.num_read_heads)]

        for i in range(seq_length):
            xi = x[:, i, :]
            controller_input = torch.cat([xi] + read_vectors, dim=1)
            hidden = self.controller(controller_input, hidden)
            h, c = hidden

            read_weights = [torch.softmax(head(h), dim=1) for head in self.read_heads]
            read_vectors = [torch.bmm(w.unsqueeze(1), self.memory.expand(batch_size, -1, -1)).squeeze(1) for w in read_weights]

            write_weights = torch.softmax(self.write_head(h), dim=1)
            erase_vector = torch.sigmoid(self.erase_head(h))
            add_vector = self.add_head(h)

            erase = torch.bmm(write_weights.unsqueeze(2), erase_vector.unsqueeze(1))
            add = torch.bmm(write_weights.unsqueeze(2), add_vector.unsqueeze(1))
            self.memory.data = self.memory * (1 - erase.mean(0)) + add.mean(0)

            output = self.output(torch.cat([h] + read_vectors, dim=1))
            outputs.append(output.unsqueeze(1))

        return torch.cat(outputs, dim=1)

# Hyperparameters
input_size = 1
hidden_size = 64
memory_size = 128
memory_vector_dim = 32
num_read_heads = 1
learning_rate = 0.001
num_epochs = 10000
batch_size = 32
seq_length = 5

# Create the model
model = DNC(input_size, hidden_size, memory_size, memory_vector_dim, num_read_heads)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# Training loop
for epoch in range(num_epochs):
    # Generate a batch of random lists of integers
    input_batch = torch.randint(1, 101, (batch_size, seq_length, input_size)).float()
    target_batch = input_batch.flip(1)

    # Forward pass
    output = model(input_batch)
    loss = criterion(output, target_batch)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 1000 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Test the model
test_input = torch.tensor([[[1], [2], [3], [4], [5]]]).float()
test_output = model(test_input)
print("Input:", test_input.squeeze().tolist())
print("Output:", [round(x) for x in test_output.squeeze().tolist()])



# Generate 100 lists of 5 random integers
# data = []
# for _ in range(10):
#     data.append(torch.tensor([[random.randint(1,100)] for _ in range(5)], dtype=torch.int64))

# targets = [item.flip(0) for item in data]



Epoch [1000/10000], Loss: 798.9494
Epoch [2000/10000], Loss: 733.3624
Epoch [3000/10000], Loss: 509.8066
Epoch [4000/10000], Loss: 376.6555
Epoch [5000/10000], Loss: 311.6450
Epoch [6000/10000], Loss: 356.1513
Epoch [7000/10000], Loss: 336.5067
Epoch [8000/10000], Loss: 341.1594
Epoch [9000/10000], Loss: 303.1803
Epoch [10000/10000], Loss: 299.5796
Input: [1.0, 2.0, 3.0, 4.0, 5.0]
Output: [49, 49, 6, 7, 7]
