In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
class GruCell(nn.Module):
    """GRU-style cell that updates one field cell from its own state and its neighbours.

    The "input" of the GRU is the concatenation of the neighbouring cells' states.
    With ``h`` the cell state and ``x`` the neighbour vector:

        r  = sigmoid(W_r h + U_r x)          # reset gate
        z  = sigmoid(W_z h + U_z x)          # update gate
        d  = z * tanh(W_h (r * h) + U_h x)   # gated candidate ("delta")
        h' = (1 - z) * h + d

    Args:
        hidden_size: length of each cell's state vector.
        n_neighbours: how many neighbour states are concatenated in ``x``
            (defaults to 4, the up/down/left/right neighbourhood).
    """

    def __init__(self, hidden_size=10, n_neighbours=4):
        super().__init__()
        neighbours_size = n_neighbours * hidden_size
        # Projections of the cell's own state (W_* above).
        self.my_reset = nn.Linear(hidden_size, hidden_size)
        self.my_update = nn.Linear(hidden_size, hidden_size)
        self.my_final = nn.Linear(hidden_size, hidden_size)
        # Projections of the concatenated neighbour states (U_* above).
        self.neighbours_reset = nn.Linear(neighbours_size, hidden_size)
        self.neighbours_update = nn.Linear(neighbours_size, hidden_size)
        self.neighbours_final = nn.Linear(neighbours_size, hidden_size)

    def forward(self, state, x):
        """Run one GRU step.

        Args:
            state: cell states; last dimension is ``hidden_size``.
            x: concatenated neighbour states; last dimension is
                ``n_neighbours * hidden_size``.

        Returns:
            Tuple ``(reset_gate, update_gate, delta, new_state)``, where ``delta`` is
            the update-gated candidate that is added to ``(1 - update_gate) * state``
            to form ``new_state``.
        """
        reset_chooser = torch.sigmoid(self.my_reset(state) + self.neighbours_reset(x))
        resetted = reset_chooser * state
        update_chooser = torch.sigmoid(self.my_update(state) + self.neighbours_update(x))
        candidate = torch.tanh(self.my_final(resetted) + self.neighbours_final(x))
        update = update_chooser * candidate
        new_state = (1 - update_chooser) * state + update
        return reset_chooser, update_chooser, update, new_state

In [4]:
def autism_loss(reset, update, delta, state):
    """Training objective (lower is better).

    The value decreases as the mean squared reset gate, update gate and state
    grow, and increases with the mean squared delta (the gated update added to
    the state).

    Args:
        reset: reset-gate activations.
        update: update-gate activations.
        delta: gated update produced by the cell.
        state: new cell state.

    Returns:
        Scalar tensor ``-E[reset^2] - E[update^2] - E[state^2] + E[delta^2]``.
    """
    reset_energy = torch.mean(reset ** 2)
    update_energy = torch.mean(update ** 2)
    state_energy = torch.mean(state ** 2)
    delta_energy = torch.mean(delta ** 2)
    return -reset_energy - update_energy - state_energy + delta_energy

In [5]:
# Experiment configuration.
field_size = 102   # side of the square field, including the zero border ring
hidden_size = 8    # length of each cell's state vector
epochs = 100       # full sweeps over the field's interior cells
batch_size = 32    # cells per optimisation step
seed = 0           # seeds torch's global RNG: network init, initial field, shuffling

# Make runs reproducible: every later random draw (weights, field, DataLoader order)
# comes from torch's global generator.
torch.manual_seed(seed)

In [6]:
# The cell being trained (placed on `device`) and its SGD-with-momentum optimiser.
net = GruCell(hidden_size=hidden_size).to(device)
optimizer = optim.SGD(params=net.parameters(), lr=1e-3, momentum=0.9)

In [17]:
# Random initial field of per-cell hidden states.
field = torch.randn(field_size, field_size, hidden_size)
# Zero the border ring (first/last row and column). Only interior cells are
# trained, so every trained cell has four neighbours.
field[0] = 0
field[-1] = 0
field[:, 0] = 0
field[:, -1] = 0

In [18]:
# Training. Each epoch:
#   1. sweeps every interior cell of the field once;
#   2. trains the cell on mini-batches of (cell state, neighbour states);
#   3. replaces the field with the states the network produced during the sweep.
interior = torch.arange(1, field_size - 1)
# Field coordinates (row, column) of every interior cell in row-major order,
# matching the reshaped cell / neighbour tensors built below. These are the
# same every epoch, so they are built once.
cell_rows = interior.repeat_interleave(field_size - 2)
cell_cols = interior.repeat(field_size - 2)
for e in range(epochs):
    # Interior cells only; the border ring stays zero.
    cells = field[1:-1, 1:-1].reshape(-1, hidden_size)
    # Upper, lower, left and right neighbours, concatenated in that order.
    neighbours = torch.cat(
        [
            field[:-2, 1:-1],  # (i - 1, j)
            field[2:, 1:-1],   # (i + 1, j)
            field[1:-1, :-2],  # (i, j - 1)
            field[1:-1, 2:],   # (i, j + 1)
        ],
        dim=-1,
    ).reshape(-1, 4 * hidden_size)
    train_set = torch.utils.data.TensorDataset(cells, neighbours, cell_rows, cell_cols)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)

    new_field = torch.zeros(field_size, field_size, hidden_size)
    running_loss = 0.0
    for batch_cells, batch_neighbours, loc_y, loc_x in train_loader:
        optimizer.zero_grad()
        reset, update, delta, state = net(batch_cells.to(device), batch_neighbours.to(device))
        loss = autism_loss(reset, update, delta, state)
        loss.backward()
        optimizer.step()
        # Store the produced states on the CPU, where the field lives.
        # Detach them first: otherwise new_field would chain the autograd
        # graph of every batch in the epoch.
        new_field[loc_y, loc_x] = state.detach().cpu()
        running_loss += loss.item()
    field = new_field
    print(f'[{e + 1}] loss: {running_loss / len(train_loader):.3f}; Mean Information : {torch.mean(field ** 2)}')

[1] loss: -1.126; Mean Information : 0.4543346166610718
[2] loss: -1.035; Mean Information : 0.30459609627723694
[3] loss: -1.050; Mean Information : 0.23837506771087646
[4] loss: -1.118; Mean Information : 0.21118928492069244
[5] loss: -1.221; Mean Information : 0.21041561663150787
[6] loss: -1.356; Mean Information : 0.23416127264499664
[7] loss: -1.516; Mean Information : 0.28507229685783386
[8] loss: -1.681; Mean Information : 0.36236658692359924
[9] loss: -1.816; Mean Information : 0.4514331817626953
[10] loss: -1.899; Mean Information : 0.5292965769767761
[11] loss: -1.941; Mean Information : 0.5854091644287109
[12] loss: -1.962; Mean Information : 0.6237891912460327
[13] loss: -1.974; Mean Information : 0.6517171859741211
[14] loss: -1.981; Mean Information : 0.6743824481964111
[15] loss: -1.985; Mean Information : 0.6937780976295471
[16] loss: -1.988; Mean Information : 0.7103978991508484
[17] loss: -1.990; Mean Information : 0.7241992354393005
[18] loss: -1.991; Mean Informati

In [20]:
from pathlib import Path

# Persist only the learned weights (the state_dict). torch.save does not create
# missing parent directories, so make sure `models/` exists first.
Path('models').mkdir(parents=True, exist_ok=True)
torch.save(net.state_dict(), 'models/first_model')