In [5]:
from main.lstm_ae import train_model as lstm_train_model
from main.policy import train_model as policy_gradient
from models.default_model import decoder_model
from models.LSTM_AE import LSTMAutoencoder
from tasks.many_rukzaks import Task, DistributionType

import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Categorical

import os
import matplotlib.pyplot as plt

In [6]:
# Generator settings for the multi-knapsack task: n knapsacks, k items, and
# (a, b) bounds of the uniform distributions used for knapsack volumes, item
# costs and item volumes (compare the printed task below).
params = dict(
    n=3,
    k=100,
    distrs=(DistributionType.UNIFORM,) * 3,
    volume_params=dict(a=40, b=60),
    cost_params=dict(a=15, b=25),
    item_volume_params=dict(a=8, b=12),
)
task = Task(**params)
print(task)

Task(n=3, k=100)
Рюкзаки: [44.24981921 45.01146004 45.86622999]
Стоимости: [15.55635065 21.96662589 15.18771919 15.47746842 24.12846368 23.49165606
 18.57142494 15.36464806 18.78717261 22.09674442 19.10834388 19.14590638
 20.33415814 22.81532348 20.52064598 18.61837363 19.67060819 15.89826946
 20.2248116  17.14693689 23.26091825 22.93226544 24.17831153 17.34256719
 21.94211128 21.38788297 17.58436578 15.95708445 22.49160408 23.20252061
 24.39197453 15.71403007 18.27775377 15.20402489 17.15833459 21.76495634
 24.0659858  23.42753026 15.58650208 15.95778123 18.66880888 24.22368189
 21.46433799 21.49640316 18.92183911 22.07439434 18.30338597 24.13815385
 23.42923051 20.28968117 17.57092586 23.58343472 19.87016413 15.22272319
 24.94738276 17.17096789 18.43580433 17.02826909 15.79026127 15.82635518
 19.20058341 19.47447879 16.04392986 16.5581578  20.83699603 15.95556489
 17.48573684 22.84859779 24.55338826 21.38794364 17.0314813  18.90356389
 19.60457747 20.07362068 22.6012645  18.02284603 

### Encoder training

In [7]:
hidden_dim_lstm = 16
# State components to encode, one LSTM autoencoder each. Order matters:
# encoder_models[idx] is later applied to task.get_state()[idx].
encoder_objects = ['r_volume', 'it_volume', 'it_cost']
# Training epochs per encoder, aligned with encoder_objects.
encoder_epochs = [512, 128, 256]
batch_size = 32
encoder_dir = 'encoder_models'

encoder_models = [LSTMAutoencoder(input_dim=1, hidden_dim=hidden_dim_lstm)
                  for _ in encoder_objects]
optimizers = [torch.optim.Adam(enc.parameters(), lr=0.01, eps=1e-8, weight_decay=0.001)
              for enc in encoder_models]
criterion = nn.MSELoss()

os.makedirs(encoder_dir, exist_ok=True)

# Train each encoder once and cache its weights; subsequent runs reload them.
for idx, (obj_name, epochs) in enumerate(zip(encoder_objects, encoder_epochs)):
    print(idx, "model")
    model_path = f'{encoder_dir}/encoder_{obj_name}.pth'
    if not os.path.exists(model_path):
        encoder_models[idx], _ = lstm_train_model(
            encoder_models[idx],
            optimizers[idx],
            criterion,
            Task,
            params,
            num_epochs=epochs,
            batch_size=batch_size,
            object=obj_name,
            verbose=True)

        torch.save(encoder_models[idx].state_dict(), model_path)
        print(f"Модель {obj_name} сохранена в {model_path}")
    else:
        # map_location='cpu': the freshly constructed module is on the CPU, and
        # load_state_dict copies the tensors into it whatever device they were
        # saved from. weights_only=True: a state_dict needs only tensors, so do
        # not unpickle arbitrary objects from disk.
        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
        encoder_models[idx].load_state_dict(state_dict)
    # Leave every encoder in inference mode, whether it was trained or loaded.
    encoder_models[idx].eval()

0 model
1 model
2 model


### Decoder training

In [9]:
hidden_size = 32
# Size of the concatenated encoder embeddings (one hidden_dim_lstm vector per encoder).
embedding_dim = hidden_dim_lstm * len(encoder_models)

# decoder1 scores knapsacks from the state embedding; decoder2 scores items from the
# same embedding plus one extra scalar, the volume of the knapsack picked by decoder1.
decoder1 = decoder_model(embedding_dim, hidden_size)
decoder2 = decoder_model(embedding_dim + 1, hidden_size)
optimizer1 = torch.optim.Adam(decoder1.parameters(), lr=0.01, eps=1e-8, weight_decay=0.001)
# NOTE(review): decoder2 uses a 5x larger learning rate than decoder1 — confirm intended.
optimizer2 = torch.optim.Adam(decoder2.parameters(), lr=0.05, eps=1e-8, weight_decay=0.001)

num_epochs = 16
samples = 128
decoder1, decoder2, losses, mean_ep_steps = policy_gradient(
    decoder1,
    decoder2,
    optimizer1,
    optimizer2,
    Task,
    params,
    num_epochs=num_epochs,
    samples=samples,
    encoders=encoder_models,
)

# Training curves as returned by policy_gradient.
fig, (ax_loss, ax_steps) = plt.subplots(1, 2, figsize=(12, 4))
ax_loss.plot(losses)
ax_loss.set(title='Policy-gradient loss', xlabel='logged step', ylabel='loss')
ax_steps.plot(mean_ep_steps)
ax_steps.set(title='Mean episode steps', xlabel='logged step', ylabel='steps')
plt.show()

  0%|          | 0/16 [00:05<?, ?it/s]


KeyboardInterrupt: 

### Checking: learned policy vs. `solve_dynamically` solution

In [5]:
# Roll out one episode of the learned two-step policy on a freshly sampled task and
# print the value returned by task.solve_dynamically() next to the policy's total.
task = Task(**params)
# NOTE(review): despite its name, `greedy` holds the result of solve_dynamically().
greedy = task.solve_dynamically()
cont = True

# Pure inference: switch every network to eval mode once and disable autograd
# for the whole rollout (decoders were previously left in train mode here).
for model in encoder_models:
    model.eval()
decoder1.eval()
decoder2.eval()

with torch.no_grad():
    while cont and task.not_end():
        state = task.get_state()
        # state[0] appears to hold knapsack volumes and state[1] item volumes
        # (see the fit mask below).
        n = state[0].shape[0]
        k = state[1].shape[0]

        # Encode each state array with its own autoencoder, fed as (1, length, 1).
        embeddings = []
        for idx, model in enumerate(encoder_models):
            _, z = model.encode_zero(torch.from_numpy(state[idx]).to(torch.float32)[None, :, None])
            embeddings.append(z)

        # Step 1: sample a knapsack.
        state1 = torch.cat(embeddings, dim=-1).unsqueeze(1)
        logits1 = decoder1(state1, n).squeeze(-1)
        dist1 = Categorical(logits=logits1)
        a1 = dist1.sample()

        # Step 2: sample an item; decoder2 additionally sees the chosen knapsack's volume.
        action = torch.tensor([state[0][a1]]).to(torch.float32)[:, None]
        embeddings.append(action)
        state2 = torch.cat(embeddings, dim=-1).unsqueeze(1)
        logits2 = decoder2(state2, k).squeeze(-1)

        # True where the item fits into the chosen knapsack (allowed positions).
        mask = (state[0][a1] >= state[1])
        mask = torch.tensor(np.reshape(mask, (1, -1)), dtype=torch.bool)
        # Forbidden items get -1e9 rather than -inf. If no item fits, an all -inf row
        # would give NaN probabilities, while -1e9 yields a uniform distribution.
        # NOTE(review): take_action then receives an item that does not fit — confirm
        # the task handles that case.
        masked_logits = logits2.clone()
        masked_logits[~mask] = -1e9
        dist2 = Categorical(logits=masked_logits)
        a2 = dist2.sample()

        cont = task.take_action(a1, a2)

print(greedy, task.total_sum)

367.15601572979006 319.96547322542466
