In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import seaborn as sns
import matplotlib.pylab as plt

In [2]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
class GruCell(nn.Module):
    """GRU-style cell that updates one grid cell's state from its neighbours.

    Every gate has two linear paths: one from the cell's own state
    (``my_*``, hidden_size -> hidden_size) and one from the neighbour input
    (``neighbours_*``, 4 * hidden_size -> hidden_size).

    Args:
        hidden_size: length of a single cell's state vector.
    """

    def __init__(self, hidden_size=10):
        super().__init__()
        neighbour_width = 4 * hidden_size
        # Layers are created (and therefore randomly initialised / registered)
        # in a fixed order; keep it so seeded init and state_dict keys stay stable.
        self.my_reset = nn.Linear(hidden_size, hidden_size)
        self.my_update = nn.Linear(hidden_size, hidden_size)
        self.my_final = nn.Linear(hidden_size, hidden_size)
        self.neighbours_reset = nn.Linear(neighbour_width, hidden_size)
        self.neighbours_update = nn.Linear(neighbour_width, hidden_size)
        self.neighbours_final = nn.Linear(neighbour_width, hidden_size)

    def forward(self, state, x):
        """Compute one update step.

        Args:
            state: current cell states; last dim must be hidden_size
                (presumably (batch, hidden_size) — TODO confirm).
            x: neighbour input; last dim must be 4 * hidden_size (the training
                loop passes four neighbour states concatenated).

        Returns:
            Tuple ``(reset_gate, update_gate, gated_candidate, new_state)`` where
            ``gated_candidate = update_gate * tanh(...)`` and
            ``new_state = (1 - update_gate) * state + gated_candidate``.
        """
        reset_gate = torch.sigmoid(self.my_reset(state) + self.neighbours_reset(x))
        update_gate = torch.sigmoid(self.my_update(state) + self.neighbours_update(x))
        # The reset gate scales how much of the old state feeds the candidate.
        candidate = torch.tanh(self.my_final(reset_gate * state) + self.neighbours_final(x))
        gated_candidate = update_gate * candidate
        new_state = (1 - update_gate) * state + gated_candidate
        return reset_gate, update_gate, gated_candidate, new_state

In [4]:
def autism_loss(reset, update, delta, state):
    """Objective for net_pos: reward large gates and states, penalise large updates.

    Returns ``mean(delta**2) - mean(reset**2) - mean(update**2) - mean(state**2)``;
    minimising it pushes gates and state magnitudes up and update magnitudes down.
    """
    gate_energy = reset.pow(2).mean() + update.pow(2).mean()
    return -gate_energy - state.pow(2).mean() + delta.pow(2).mean()

def anti_autism_loss(reset, update, delta, state):
    """Objective for net_neg: reward large gates, penalise large states and updates.

    Returns ``mean(state**2) + mean(delta**2) - mean(reset**2) - mean(update**2)``;
    it differs from ``autism_loss`` only in the sign of the state term.
    """
    gate_energy = reset.pow(2).mean() + update.pow(2).mean()
    return -gate_energy + state.pow(2).mean() + delta.pow(2).mean()

def distance_loss(states, x, y, scale=1000.0):
    """Root-mean-square mismatch between grid distances and state-space distances.

    The batch is split in half and element ``k`` of the first half is paired with
    element ``k`` of the second half. For each pair the squared grid distance
    ``dx**2 + dy**2`` is compared with the squared Euclidean distance between the
    two state vectors; the difference is divided by ``scale`` and the RMS over all
    pairs is returned.

    Args:
        states: per-element state vectors; squared differences are summed over
            dim 1, so presumably shape (batch, hidden) — TODO confirm with callers.
        x, y: 1-D coordinate tensors with one entry per element of ``states``.
        scale: divisor applied to each mismatch (default 1000).

    Returns:
        A 0-d tensor on ``states.device``.

    With an odd batch size the final element has no partner and is ignored.
    """
    # Put the coordinates next to the states so the subtraction below is valid
    # regardless of which device the states live on.
    x = x.to(states.device)
    y = y.to(states.device)
    half = x.shape[0] // 2
    # Slice both halves to exactly `half` elements so they always line up
    # element-for-element (an odd trailing element is dropped).
    x1, x2 = x[:half], x[half:2 * half]
    y1, y2 = y[:half], y[half:2 * half]
    states1, states2 = states[:half], states[half:2 * half]
    grid_dist_sq = (x2 - x1) ** 2 + (y2 - y1) ** 2
    state_dist_sq = torch.sum((states2 - states1) ** 2, dim=1)
    range_deltas = (grid_dist_sq - state_dist_sq) / scale
    return torch.mean(range_deltas ** 2) ** 0.5
    

def prefer_biggest_loss(reset, update, delta, state):
    """Gate/update loss weighted per row by that row's mean squared state.

    Each row's weight is ``mean(state[row]**2)``; the result is
    ``-mean(w*reset**2) - mean(w*update**2) + mean(w*delta**2)``, so rows whose
    state is already large dominate the objective.
    """
    weight = state.pow(2).mean(dim=1).reshape(delta.shape[0], 1)

    def weighted_energy(tensor):
        # Mean of the squared tensor, scaled row-wise by `weight`.
        return torch.mean(weight * tensor ** 2)

    return -weighted_energy(reset) - weighted_energy(update) + weighted_energy(delta)

In [5]:
field_size = 102   # side length of the square field, including the zero border ring
hidden_size = 8    # length of each cell's state vector
epochs = 15
batch_size = 32
# Fix torch's global RNG so Restart & Run All reproduces the same initial field,
# network initialisation and DataLoader shuffling.
seed = 0
torch.manual_seed(seed)

In [25]:
# Two competing cells: net_pos is optimised with L2 weight decay, net_neg without.
# net_pos is created first, then net_neg; keep this order so seeded init is unchanged.
net_pos = GruCell(hidden_size=hidden_size).to(device)
net_neg = GruCell(hidden_size=hidden_size).to(device)
optimizer_pos = optim.Adam(params=net_pos.parameters(), lr=1e-4, weight_decay=1e-3)
optimizer_neg = optim.Adam(params=net_neg.parameters(), lr=1e-4)

In [26]:
# Random initial state for every cell of the field.
field = torch.randn(field_size, field_size, hidden_size)
# Zero the border ring (first/last row and first/last column). The training loop
# only updates interior cells, so these stay as fixed zero-valued neighbours.
field[0] = 0
field[-1] = 0
field[:, 0] = 0
field[:, -1] = 0

In [None]:
# When True, inject fresh Gaussian noise into the whole field after every epoch.
pump_energy = False
# Experimental constant: standard deviation of the noise injected when pump_energy is on.
pump_scale = 0.3

# Coordinates of every interior cell in row-major order (row i outer, column j inner).
# Border cells are never updated.
interior = field_size - 2
rows = torch.arange(1, field_size - 1).repeat_interleave(interior)
cols = torch.arange(1, field_size - 1).repeat(interior)

for e in range(epochs):
    new_field = torch.zeros(field_size, field_size, hidden_size)
    # Each interior cell's own state, in the same order as `rows`/`cols`.
    cells_all = field[1:-1, 1:-1].reshape(-1, hidden_size)
    # Its four neighbours concatenated as up (i-1, j), down (i+1, j),
    # left (i, j-1), right (i, j+1).
    neighbours_all = torch.cat(
        [field[:-2, 1:-1], field[2:, 1:-1], field[1:-1, :-2], field[1:-1, 2:]], dim=-1
    ).reshape(-1, 4 * hidden_size)
    train_set = torch.utils.data.TensorDataset(cells_all, neighbours_all, rows, cols)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    running_loss = 0.0  # accumulates loss_pos only
    for cells, neighbours, loc_y, loc_x in train_loader:
        cells = cells.to(device)
        neighbours = neighbours.to(device)
        optimizer_pos.zero_grad()
        optimizer_neg.zero_grad()
        reset_pos, update_pos, delta_pos, state_pos = net_pos(cells, neighbours)
        reset_neg, update_neg, delta_neg, state_neg = net_neg(cells, neighbours)
        # Elementwise: True where net_pos proposes a larger gated update than net_neg.
        mask = delta_pos > delta_neg
        # net_neg is trained on the elements where net_pos is larger, net_pos on the rest;
        # the 0.5 weight on loss_pos is an experimental constant.
        loss_neg = anti_autism_loss(mask * reset_neg, mask * update_neg, mask * delta_neg, mask * state_neg)
        loss_pos = autism_loss(~mask * reset_pos, ~mask * update_pos, ~mask * delta_pos, ~mask * state_pos) * 0.5
        loss_pos.backward()
        loss_neg.backward()
        optimizer_pos.step()
        optimizer_neg.step()
        with torch.no_grad():
            # A cell takes net_pos's new state when the FIRST hidden channel of mask is
            # True, otherwise net_neg's; the (batch, 1) selector broadcasts over channels.
            chosen = torch.where(mask[:, :1], state_pos, state_neg)
            # new_field lives on the CPU, so bring the chosen states over explicitly.
            new_field[loc_y, loc_x] = chosen.cpu()
        running_loss += loss_pos.item()
    field = new_field
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(field.pow(2).mean(dim=2).numpy(), ax=ax, square=True, vmin=0, vmax=1)
    ax.set_title(f'Epoch {e + 1}: mean squared state per cell')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across epochs
    if pump_energy:
        field += torch.randn(field_size, field_size, hidden_size) * pump_scale
    print(f'[{e + 1}] loss: {running_loss /  len(train_loader):.3f}; Mean Information : {torch.mean(field ** 2)}')

In [None]:
# The notebook trains two networks (net_pos and net_neg); store both in one checkpoint.
# Reload with: ckpt = torch.load('models/first_model'); net_pos.load_state_dict(ckpt['pos'])
os.makedirs('models', exist_ok=True)  # torch.save does not create missing directories
torch.save({'pos': net_pos.state_dict(), 'neg': net_neg.state_dict()}, 'models/first_model')