# Создание модели

Модель сначала проецирует входные значения линейным слоем (Linear), затем LSTM извлекает зависимости между элементами последовательности. <br>
Attention позволяет выделить наиболее значимые шаги среди выходов LSTM. Маска позволяет игнорировать паддинги. <br>
В результате модель агрегирует информацию в контекстный вектор и делает регрессионное предсказание следующего значения последовательности. <br>

In [2]:
from models.lstm_with_attention import LSTMWithAttention

### Визуализация модели

In [6]:
import torch
from torchview import draw_graph

# Model hyper-parameters used for this visualisation / sanity check.
input_dim = 1
hidden_dim = 128
num_layers = 2
seq_len = 20
batch_size = 2

# Dummy batch: inputs of shape (batch_size, seq_len) and an all-True mask (no padding).
x = torch.randn(batch_size, seq_len)
mask = torch.ones(batch_size, seq_len).bool()

# Model initialisation
model = LSTMWithAttention(input_dim=input_dim, hidden_dim=hidden_dim, num_layers=num_layers, dropout=0.1, bidirectional=False)

# Count both figures explicitly so that frozen parameters (requires_grad=False)
# are excluded from the trainable count instead of repeating the total.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters:     {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Render the computation graph to an SVG file named after graph_name.
draw_graph(model, input_data=(x, mask), graph_name='LSTM with Attention', expand_nested=True, roll=True).visual_graph.render(format='svg')




Total parameters:     264,962
Trainable parameters: 264,962


'LSTM with Attention.gv.svg'

# Генерация данных


In [1]:
from data.generators import generate_data

# Size of the synthetic dataset (the progress bar below shows roughly 15k
# sequences/s, so the full run takes about a minute).
NUM_SAMPLES = 1_000_000

data, labels, masks = generate_data(num_samples=NUM_SAMPLES)

# Report the shape of every returned collection.
for shape_title, collection in (("Data", data), ("Labels", labels), ("Masks", masks)):
    print(f"{shape_title} shape:", collection.shape)

Generating sequences:  24%|██▍       | 237690/1000000 [00:15<00:51, 14867.33it/s]


KeyboardInterrupt: 

In [4]:
from data.dataset import build_dataloaders

# Train / validation / test fractions.
split_fractions = (0.8, 0.1, 0.1)

train_loader, val_loader, test_loader = build_dataloaders(
    data,
    labels,
    masks,
    split=split_fractions,
)

In [None]:
# Peek at one training batch (presumably (x, mask, y), as test_model unpacks
# test batches) — displayed as the cell output.
first_train_batch = next(iter(train_loader))
first_train_batch

# Обучение

In [None]:
import torch
from training.trainer import train_model

# Resolve the device once and reuse it for both the status line and training.
cuda_available = torch.cuda.is_available()
device = 'cuda' if cuda_available else 'cpu'

# Pick the status text first: `%` binds tighter than a conditional expression,
# so inlining the `if/else` after `%` would drop the "Torch CUDA status" prefix
# whenever CUDA is not available.
cuda_status = '✅ Available' if cuda_available else '❌ NOT available'
print('Torch CUDA status: \n\t%s' % cuda_status, '\n')

# Training
trained_model, train_losses, val_losses = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=2_000,
    patience=20,
    lr=1e-3,
    device=device,
    checkpoint_every=50,
)

Torch CUDA status: 
	✅ Available 



                                                                                        

[Epoch 1] Train Loss: 7957136.0464 | Val Loss: 7623519.1869
Example input: [7.8457403 2.8538585 2.0083718 1.5180016 1.6150894 1.5737698 1.6822139
 1.3469346 1.8272529 1.4052224 1.3979295 1.4048356 1.0697714 1.3430461
 1.3307256 0.9500988]
Example target: 1.2512695789337158
Example prediction: -0.14015483856201172


                                                                                        

[Epoch 2] Train Loss: 7187625.3435 | Val Loss: 6904709.3477


                                                                                        

[Epoch 3] Train Loss: 6509382.4695 | Val Loss: 6223053.2108


                                                                                        

[Epoch 4] Train Loss: 5917429.5591 | Val Loss: 5665147.0888


                                                                                        

[Epoch 5] Train Loss: 5423526.9223 | Val Loss: 5232672.6995


                                                                                        

[Epoch 6] Train Loss: 4996381.6710 | Val Loss: 4726086.2742


                                                                                        

[Epoch 7] Train Loss: 4529546.0835 | Val Loss: 4272316.4435


                                                                                        

[Epoch 8] Train Loss: 4247364.0840 | Val Loss: 3896796.4607


                                                                                        

[Epoch 9] Train Loss: 3865682.9015 | Val Loss: 3584806.0433


                                                                                         

[Epoch 10] Train Loss: 3579224.2274 | Val Loss: 3279558.8927


                                                                                         

[Epoch 11] Train Loss: 3280051.6591 | Val Loss: 2980968.2701


                                                                                         

[Epoch 12] Train Loss: 3008078.1782 | Val Loss: 2720398.8204


                                                                                         

[Epoch 13] Train Loss: 2835497.0900 | Val Loss: 2781793.2676


                                                                                         

[Epoch 14] Train Loss: 2565468.1585 | Val Loss: 2226185.6577


                                                                                         

[Epoch 15] Train Loss: 2453172.3503 | Val Loss: 2107748.4486


                                                                                         

[Epoch 16] Train Loss: 2358667.7150 | Val Loss: 1922177.4947


                                                                                         

[Epoch 17] Train Loss: 2088518.9918 | Val Loss: 1792082.8666


                                                                                         

[Epoch 18] Train Loss: 2096576.2416 | Val Loss: 1824506.6147


                                                                                         

[Epoch 19] Train Loss: 1977042.7256 | Val Loss: 1545631.0965


                                                                                         

[Epoch 20] Train Loss: 1791105.2993 | Val Loss: 1404366.5393


                                                                                         

[Epoch 21] Train Loss: 1688854.4196 | Val Loss: 1314974.8765


                                                                                         

[Epoch 22] Train Loss: 1585496.1097 | Val Loss: 1211513.6850


                                                                                         

[Epoch 23] Train Loss: 1521311.1871 | Val Loss: 1175109.4426


                                                                                         

[Epoch 24] Train Loss: 1463469.3427 | Val Loss: 1024020.8407


                                                                                         

[Epoch 25] Train Loss: 1366838.8935 | Val Loss: 1261856.6523


                                                                                         

[Epoch 26] Train Loss: 1363415.6447 | Val Loss: 1081075.2365


                                                                                         

[Epoch 27] Train Loss: 1302634.8457 | Val Loss: 923608.1406


                                                                                         

[Epoch 28] Train Loss: 1218429.1651 | Val Loss: 835073.0111


                                                                                         

[Epoch 29] Train Loss: 1144705.4867 | Val Loss: 795965.4185


                                                                                         

[Epoch 30] Train Loss: 1089783.5336 | Val Loss: 683721.5256


                                                                                         

[Epoch 31] Train Loss: 1038287.7025 | Val Loss: 702043.0234


                                                                                         

[Epoch 32] Train Loss: 990549.7339 | Val Loss: 662496.3730


                                                                                         

[Epoch 33] Train Loss: 946261.3119 | Val Loss: 695280.4111


                                                                                         

[Epoch 34] Train Loss: 915177.3741 | Val Loss: 651878.7291


                                                                                         

[Epoch 35] Train Loss: 878862.6178 | Val Loss: 601688.8259


[Epoch 36] Training:  32%|███▏      | 3962/12500 [00:32<01:02, 136.04it/s, loss=3.34e+5]

### Save model

In [None]:
# TODO: delete later
from pathlib import Path

state_path = Path("weights/weird ones/HOPE_IT_WORKS_v2.pt")
full_model_path = Path("saved_models/lstm_full_v2.pt")

# torch.save does not create missing parent directories, so create them first.
for target_path in (state_path, full_model_path):
    target_path.parent.mkdir(parents=True, exist_ok=True)

# Checkpoint with the model's config and its state_dict.
torch.save({
    "config": trained_model.get_config(),
    "state_dict": trained_model.state_dict(),
}, state_path)
# Whole pickled module: convenient, but loading it depends on the class
# definition being importable from the same module path.
torch.save(trained_model, full_model_path)
print("Model saved ✅")

NameError: name 'trained_model' is not defined

In [None]:
from pathlib import Path

# NOTE(review): "epoch234_val123" is hard-coded and does not come from the
# actual run (train_losses / val_losses) — rename before relying on it.
state_path = Path("weights/lstm_epoch234_val123.pt")
full_model_path = Path("saved_models/lstm_full_v2.pt")

# torch.save does not create missing parent directories, so create them first.
for target_path in (state_path, full_model_path):
    target_path.parent.mkdir(parents=True, exist_ok=True)

# Checkpoint with the model's own config and its state_dict.
torch.save({
    "config": trained_model.get_config(),
    "state_dict": trained_model.state_dict(),
}, state_path)
# Whole pickled module (tied to the class definition's import path).
torch.save(trained_model, full_model_path)
print("Model saved ✅")

# Тестирование

In [None]:


def test_model(model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'):
    model.eval()
    model.to(device)

    total_loss = 0.0
    criterion = torch.nn.MSELoss()
    predictions = []
    targets = []
    all_attn_weights = []

    with torch.no_grad():
        for x, mask, y in test_loader:
            x, mask, y = x.to(device), mask.to(device), y.to(device)
            output, attn_weights = model(x, mask)
            loss = criterion(output.squeeze(), y)
            total_loss += loss.item() * x.size(0)

            predictions.extend(output.squeeze().cpu().numpy())
            targets.extend(y.cpu().numpy())
            all_attn_weights.extend(attn_weights.cpu().numpy())

    avg_loss = total_loss / len(test_loader.dataset)
    print(f"\nTest MSE: {avg_loss:.4f}")
    return predictions, targets, all_attn_weights


def plot_loss_curves(train_losses, val_losses):
    """Plot training and validation loss curves on a shared axis.

    Args:
        train_losses: sequence of training losses (presumably one per epoch,
            as returned by train_model — confirm).
        val_losses: sequence of validation losses, aligned with ``train_losses``.
    """
    # Import locally: the notebook's global `import matplotlib.pyplot as plt`
    # sits in a later cell than the first call to this function, so relying on
    # the global raises NameError on a top-to-bottom run.
    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def visualize_attention(attn_weights, sequence, mask=None, idx=0):
    """Bar-plot one sample's attention weights and print them with its input values.

    Args:
        attn_weights: indexable collection of per-sample attention weights
            (e.g. the list returned by ``test_model``).
        sequence: batch of input sequences, either a torch tensor (any device)
            or an array-like, indexed in the same order as ``attn_weights``.
        mask: optional batch of masks in the same order; weights at positions
            where the mask is False/0 are zeroed for display.
        idx: index of the sample to show.
    """
    # Local imports: the notebook's global `plt` / `np` imports live in later
    # cells, so relying on them raises NameError on a top-to-bottom run.
    import matplotlib.pyplot as plt
    import numpy as np

    def to_numpy(arr):
        # Tensors need detach()/cpu() before numpy(); plain NumPy arrays have no
        # .cpu(), so pass array-likes straight through np.asarray.
        return arr.detach().cpu().numpy() if hasattr(arr, "detach") else np.asarray(arr)

    weights = to_numpy(attn_weights[idx])
    values = to_numpy(sequence[idx])
    if mask is not None:
        weights = weights * to_numpy(mask[idx])

    plt.figure(figsize=(10, 2))
    plt.bar(range(len(weights)), weights, alpha=0.6)
    plt.title("Attention Weights")
    plt.xlabel("Time Step")
    plt.ylabel("Weight")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    print("Sequence Values:")
    print(values)
    print("\nAttention:")
    print(weights)

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Run the model over the test split; attention weights are not needed here.
preds, targs, _ = test_model(trained_model, test_loader)

# Cross-check the MSE with sklearn and add MAE.
mse, mae = (metric(targs, preds) for metric in (mean_squared_error, mean_absolute_error))

print(f"Test MSE: {mse:.4f} — MAE: {mae:.4f}")

In [None]:
# Collect test predictions / targets / attention for the plotting cells below.
preds, targs, attns = test_model(trained_model, test_loader)
plot_loss_curves(train_losses, val_losses)

# `attns` follows test-set order while `data` / `masks` cover the whole generated
# dataset, so attns[0] and data[0] are not guaranteed to be the same sample.
# Recompute attention for one test batch so weights and inputs line up.
trained_model.eval()
sample_x, sample_mask, _ = next(iter(test_loader))
model_device = next(trained_model.parameters()).device
with torch.no_grad():
    _, sample_attn = trained_model(sample_x.to(model_device), sample_mask.to(model_device))
visualize_attention(sample_attn.cpu().numpy(), sample_x, sample_mask, idx=0)

In [None]:
import matplotlib.pyplot as plt

# Overlay the first 100 test predictions on their true targets.
n_show = 100
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(preds[:n_show], label='Predicted')
ax.plot(targs[:n_show], label='True', alpha=0.7)
ax.set_title('Model Predictions vs True Values (Sample)')
ax.set_xlabel('Sample Index')
ax.set_ylabel('Value')
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()


In [None]:
# Use the raw inputs and mask of the first test batch.
test_batch = next(iter(test_loader))
x_batch, mask_batch, y_batch = test_batch

# Recompute attention for exactly this batch: `attns` from test_model comes from
# a separate pass over test_loader and only matches this batch if the loader
# does not shuffle. eval() keeps dropout off for a deterministic forward pass.
trained_model.eval()
batch_device = next(trained_model.parameters()).device
with torch.no_grad():
    _, batch_attn = trained_model(x_batch.to(batch_device), mask_batch.to(batch_device))

# Show the attention weights and the sequence values.
visualize_attention(batch_attn.cpu().numpy(), x_batch, mask_batch, idx=0)

In [None]:
import matplotlib.pyplot as plt

# Parity plot: points on the dashed diagonal are perfect predictions.
lo, hi = min(targs), max(targs)

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(targs, preds, alpha=0.5)
ax.plot([lo, hi], [lo, hi], 'r--', label='Ideal')
ax.set_xlabel('True Value')
ax.set_ylabel('Predicted Value')
ax.set_title('Predicted vs True Values')
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()


In [None]:
import numpy as np

# Compare the first (up to) 10 test samples side by side; clamp to the number
# of available samples so a small test set does not raise IndexError.
N = min(10, len(targs))
# Use a dedicated name so the dummy model input `x` from earlier is not clobbered.
positions = np.arange(N)

plt.figure(figsize=(10, 4))
plt.bar(positions - 0.2, targs[:N], width=0.4, label='True')
plt.bar(positions + 0.2, preds[:N], width=0.4, label='Predicted')
plt.xlabel('Sample Index')
plt.ylabel('Value')
plt.title(f'True vs Predicted (first {N} samples)')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
def visualize_attention_bars(attn_weights, input_sequence, mask=None, idx=0):
    """Bar-plot one sample's attention weights, labelling each bar with its input value.

    Args:
        attn_weights: indexable collection of per-sample attention weights.
        input_sequence: batch of input sequences, either a torch tensor (any
            device) or an array-like, indexed in the same order as ``attn_weights``.
        mask: optional batch of masks in the same order; weights at positions
            where the mask is False/0 are zeroed for display.
        idx: index of the sample to plot.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    def to_numpy(arr):
        # Tensors need detach()/cpu() before numpy(); array-likes pass through.
        return arr.detach().cpu().numpy() if hasattr(arr, "detach") else np.asarray(arr)

    weights = to_numpy(attn_weights[idx])
    # Flatten so each tick label formats a scalar even for (seq, 1) inputs.
    values = to_numpy(input_sequence[idx]).reshape(-1)
    if mask is not None:
        weights = weights * to_numpy(mask[idx])

    positions = np.arange(len(values))
    # One coolwarm colour per position, sampled the way seaborn's palette does.
    # This is drawn with matplotlib directly because seaborn >= 0.13 deprecates
    # `palette` without `hue`.
    colors = plt.cm.coolwarm(np.linspace(0, 1, len(values) + 2)[1:-1])

    plt.figure(figsize=(12, 2.5))
    plt.bar(positions, weights, color=colors, alpha=0.6)
    plt.xticks(ticks=positions, labels=[f'{v:.2f}' for v in values], rotation=45)
    plt.title("Attention weights per timestep (values shown below)")
    plt.xlabel("Sequence Element (Value)")
    plt.ylabel("Attention Weight")
    plt.tight_layout()
    plt.show()

# `attns` follows test-set order while `data` / `masks` cover the whole dataset,
# so pair them is unreliable; recompute attention for one test batch instead.
trained_model.eval()
bars_x, bars_mask, _ = next(iter(test_loader))
bars_device = next(trained_model.parameters()).device
with torch.no_grad():
    _, bars_attn = trained_model(bars_x.to(bars_device), bars_mask.to(bars_device))
visualize_attention_bars(bars_attn.cpu().numpy(), bars_x, bars_mask, idx=0)

In [None]:
# Average attention weight at each position over all test samples. The mask is
# not applied here, so padded positions contribute their (possibly zero) weights.
stacked_attn = torch.stack([torch.as_tensor(sample_weights) for sample_weights in attns])
mean_attn_by_position = stacked_attn.mean(dim=0)

plt.plot(mean_attn_by_position.numpy())
plt.title("Mean Attention Weight by Position")
plt.show()