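# Лабораторная работа № 8

Цель: обучить сеть NARX (нелинейную авторегрессионную модель с внешним входом) воспроизводить выходной сигнал динамической системы по её входному сигналу.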

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import tqdm
import matplotlib.pyplot as plt
from collections import deque

from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd
import plotly.express as px

In [2]:
def plot_history(h):
    """Plot a per-epoch training-loss curve with plotly.

    Args:
        h: sequence of loss values, one per epoch (e.g. the list returned by ``fit``).
            An empty sequence yields an empty plot instead of raising IndexError.
    """
    fig = go.Figure()
    epochs = np.arange(len(h))
    fig.add_trace(go.Scatter(x=epochs, y=h, name='Train loss'))
    # Title shows the final epoch's loss; guard against an empty history.
    title = f"Loss: {h[-1]:.6f}" if len(h) else "Loss: no history"
    fig.update_layout(title_text=title, xaxis_title='Epoch', yaxis_title='Loss')
    fig.show()

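# NARX

Модель состоит из двух линий задержки (TDL) — для входа и для собственного выхода сети — и одного скрытого слоя с активацией $\tanh$:
$\hat y_t = \tanh\left(x_{t-d_u} W_1 + b_1 + \hat y_{t-d_y} W_2\right) W_3 + b_3$, где $d_u$ и $d_y$ — задержки входа и выхода.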

In [3]:
class TDL(nn.Module):
    """Fixed-length delay line: a value pushed in is returned `delays` reads later.

    The line is pre-filled with `delays` zero tensors of shape (1, in_features), so the
    first `delays` calls to ``forward`` return zeros. Callers alternate ``forward()``
    (take the oldest value) and ``push(x)`` (store the newest one), which keeps the
    length constant. Only the single oldest value is returned, not all taps.
    """

    def __init__(self, in_features, delays):
        super(TDL, self).__init__()
        if delays < 1:
            # With no pre-filled slots the first forward() would pop from an empty
            # deque and fail with an opaque IndexError; reject the value up front.
            raise ValueError(f"delays must be >= 1, got {delays}")
        self.in_features = in_features
        self.delays = delays
        self.line = deque()
        self.clear()

    def clear(self):
        """Reset the line to `delays` zero tensors of shape (1, in_features)."""
        self.line.clear()
        for _ in range(self.delays):
            self.line.append(torch.zeros(1, self.in_features))

    def push(self, x):
        """Insert the newest value at the left (input) end of the line."""
        self.line.appendleft(x)

    def forward(self):
        """Remove and return the oldest value from the right (output) end."""
        return self.line.pop()

In [4]:
class NARX(nn.Module):
    """Single-hidden-layer NARX network built on two delay lines.

    Each forward step combines the input delayed by `in_delay` steps with the
    network's own output delayed by `out_delay` steps:

        hidden = tanh(x_delayed @ w1 + b1 + y_delayed @ w2)
        out    = hidden @ w3 + b3

    It then stores detached copies of the current input and output in the
    delay lines. The current `x` therefore only influences future outputs.
    """

    def __init__(self, in_features, hid_features, out_features, in_delay, out_delay):
        super().__init__()
        # Delay lines for past inputs and past outputs (they hold no parameters).
        self.in_tdl = TDL(in_features, in_delay)
        self.out_tdl = TDL(out_features, out_delay)

        # Keep this creation order: it fixes which torch.randn draw goes to which weight.
        self.w1 = nn.Parameter(torch.randn(in_features, hid_features))   # input -> hidden
        self.b1 = nn.Parameter(torch.zeros(hid_features))
        self.w2 = nn.Parameter(torch.randn(out_features, hid_features))  # feedback -> hidden
        self.w3 = nn.Parameter(torch.randn(hid_features, out_features))  # hidden -> output
        self.b3 = nn.Parameter(torch.zeros(out_features))

    def clear(self):
        """Reset both delay lines to zeros (start of a new sequence)."""
        for delay_line in (self.in_tdl, self.out_tdl):
            delay_line.clear()

    def forward(self, x):
        past_x = self.in_tdl()
        past_y = self.out_tdl()
        hidden = torch.tanh(past_x @ self.w1 + self.b1 + past_y @ self.w2)
        y_hat = hidden @ self.w3 + self.b3

        # Store copies cut from the autograd graph so gradients never flow
        # back through earlier time steps.
        self.in_tdl.push(x.detach().clone())
        self.out_tdl.push(y_hat.detach().clone())
        return y_hat

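## Данные

Входной сигнал: $u(k) = \frac{1}{4}\sin\left(k^2 - 6k - 2\pi\right)$ на сетке $k = 0, h, 2h, \dots \approx 10$ с шагом $h = 0.01$.

Целевой сигнал задаётся рекуррентно по индексам сетки: $y_{i+1} = \dfrac{y_i}{1 + y_i^2} + u_i^3$, $y_0 = 0$, где $u_i = u(k_i)$.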

In [5]:
def f(k):
    """Input excitation u(k) = sin(k**2 - 6k - 2*pi) / 4; works element-wise on numpy arrays."""
    return np.sin(k**2 - 6*k - 2*np.pi) / 4


h = 0.01                        # grid step
k = np.arange(0, 10 + h, h)     # grid 0, h, 2h, ... up to ~10

u = f(k)
# Target from the recurrence y[i+1] = y[i] / (1 + y[i]**2) + u[i]**3 with y[0] = 0.
y = [0]
for u_prev in u[:-1]:
    y_prev = y[-1]
    y.append(y_prev / (1 + y_prev ** 2) + u_prev ** 3)

In [6]:
# Overview of the generated data: input u(k) and the target y it drives.
fig = go.Figure()
fig.add_trace(go.Scatter(x=k, y=u, name='Input signal'))
fig.add_trace(go.Scatter(x=k, y=y, name='Target signal'))
fig.update_layout(title_text='Input u(k) and target y(k)', xaxis_title='k', yaxis_title='Value')
fig.show()

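## Обучение и тестирование сети

Сеть обучается на скользящих окнах длины `window` по сигналам $u$ и $y$. Качество проверяется на той же обучающей последовательности: отдельной тестовой выборки нет.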

In [7]:
def fit(model, optim, crit, epochs, data):
    """Train a stateful sequence model with one optimiser step per batch.

    The model's delay lines are reset once per epoch (``model.clear()``) and then
    carried across consecutive batches, so `data` is presumably meant to be iterated
    in order (the loaders in this notebook use ``shuffle=False``).

    Args:
        model: module exposing ``clear()`` and ``forward(x)``.
        optim: optimiser over ``model.parameters()``.
        crit: loss callable following the PyTorch convention ``crit(prediction, target)``.
        epochs: number of passes over `data`.
        data: iterable of ``(X_batch, Y_batch)`` pairs that also supports ``len()``.

    Returns:
        List with the mean batch loss of every epoch.
    """
    model.train()
    train_loss = []
    pbar = tqdm.trange(epochs, ascii=True)
    for epoch in pbar:
        model.clear()
        avg_loss = 0
        for X_batch, Y_batch in data:
            optim.zero_grad()

            output = model(X_batch)
            # PyTorch losses take (input, target); the order matters for
            # asymmetric criteria such as BCELoss or KLDivLoss.
            loss = crit(output, Y_batch)
            loss.backward()

            optim.step()
            avg_loss += loss.item() / len(data)
        train_loss.append(avg_loss)
        pbar.set_description(f'Epoch: {epoch+1}. Loss: {avg_loss:.8f}')
    return train_loss
        
def predict(model, data, window):
    """Rebuild a full-length prediction from a trained windowed model.

    The first ``window - 1`` points come from the model's output on the first
    window right after a state reset. After a second reset, every window in
    `data` contributes its last output element. The result has
    ``window - 1 + len(data)`` entries. The model is left in eval mode.
    """
    model.eval()
    with torch.no_grad():
        model.clear()
        first_window = next(iter(data))[0]
        head = model(first_window).detach().numpy()[0, :window - 1]
        prediction = list(head)

        model.clear()
        prediction.extend(
            model(x_window).detach().numpy().item(-1) for x_window, _ in data
        )
    return prediction

In [8]:
def plot_result(model, data, window, k, y):
    """Plot the reference signal against the model's reconstructed prediction.

    Args:
        model, data, window: passed unchanged to ``predict``.
        k: x-axis values (the time grid).
        y: reference signal, expected to have one value per predicted point.
    """
    pred = predict(model, data, window)
    # Mean squared error over the overlapping part, shown in the title so runs are comparable.
    n = min(len(pred), len(y))
    if n:
        diff = np.asarray(pred[:n], dtype=np.float64) - np.asarray(y[:n], dtype=np.float64)
        mse = float(np.mean(diff ** 2))
    else:
        mse = float('nan')
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=k, y=y, name='Original signal'))
    fig.add_trace(go.Scatter(x=k, y=pred, name='Predict signal'))
    fig.update_layout(title_text=f'Prediction MSE: {mse:.6f}', xaxis_title='k', yaxis_title='Value')
    fig.show()

In [33]:
window = 8
# One sample per sliding window: (u[s:s+window], y[s:s+window]) as float32 arrays.
train_data = [
    (
        np.array(u[start:start + window], dtype=np.float32),
        np.array(y[start:start + window], dtype=np.float32),
    )
    for start in range(len(k) - window + 1)
]
# batch_size=1, no shuffling: NARX carries delay-line state between consecutive windows.
train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False)

In [34]:
torch.manual_seed(0)  # NARX weights come from torch.randn; fix the seed so re-runs reproduce
model = NARX(in_features=window, hid_features=10, out_features=window, in_delay=2, out_delay=3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
hist = fit(model, optimizer, nn.MSELoss(), epochs=60, data=train_loader)

Epoch: 60. Loss: 0.00184372: 100%|##########| 60/60 [00:36<00:00,  1.66it/s]


In [35]:
plot_history(h=hist)  # loss curve of the run above

In [36]:
plot_result(model=model, data=train_loader, window=window, k=k, y=y)  # prediction vs target

In [13]:
window = 9
# One sample per sliding window: (u[s:s+window], y[s:s+window]) as float32 arrays.
train_data = [
    (
        np.array(u[start:start + window], dtype=np.float32),
        np.array(y[start:start + window], dtype=np.float32),
    )
    for start in range(len(k) - window + 1)
]
# batch_size=1, no shuffling: NARX carries delay-line state between consecutive windows.
train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False)

In [14]:
torch.manual_seed(0)  # NARX weights come from torch.randn; fix the seed so re-runs reproduce
model = NARX(in_features=window, hid_features=15, out_features=window, in_delay=3, out_delay=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
hist = fit(model, optimizer, nn.MSELoss(), epochs=70, data=train_loader)

Epoch: 70. Loss: 0.00259667: 100%|##########| 70/70 [00:44<00:00,  1.58it/s]


In [15]:
plot_history(h=hist)  # loss curve of the run above

In [16]:
plot_result(model=model, data=train_loader, window=window, k=k, y=y)  # prediction vs target

In [17]:
window = 9
# One sample per sliding window: (u[s:s+window], y[s:s+window]) as float32 arrays.
train_data = [
    (
        np.array(u[start:start + window], dtype=np.float32),
        np.array(y[start:start + window], dtype=np.float32),
    )
    for start in range(len(k) - window + 1)
]
# batch_size=1, no shuffling: NARX carries delay-line state between consecutive windows.
train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False)

In [18]:
torch.manual_seed(0)  # NARX weights come from torch.randn; fix the seed so re-runs reproduce
model = NARX(in_features=window, hid_features=14, out_features=window, in_delay=3, out_delay=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
hist = fit(model, optimizer, nn.MSELoss(), epochs=100, data=train_loader)

Epoch: 100. Loss: 0.00474418: 100%|##########| 100/100 [01:02<00:00,  1.59it/s]


In [19]:
plot_history(h=hist)  # loss curve of the run above

In [20]:
plot_result(model=model, data=train_loader, window=window, k=k, y=y)  # prediction vs target

In [21]:
window = 3
# One sample per sliding window: (u[s:s+window], y[s:s+window]) as float32 arrays.
train_data = [
    (
        np.array(u[start:start + window], dtype=np.float32),
        np.array(y[start:start + window], dtype=np.float32),
    )
    for start in range(len(k) - window + 1)
]
# batch_size=1, no shuffling: NARX carries delay-line state between consecutive windows.
train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False)

In [22]:
torch.manual_seed(0)  # NARX weights come from torch.randn; fix the seed so re-runs reproduce
model = NARX(in_features=window, hid_features=20, out_features=window, in_delay=2, out_delay=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
hist = fit(model, optimizer, nn.MSELoss(), epochs=50, data=train_loader)

Epoch: 50. Loss: 0.00576322: 100%|##########| 50/50 [00:31<00:00,  1.59it/s]


In [23]:
plot_history(h=hist)  # loss curve of the run above

In [24]:
plot_result(model=model, data=train_loader, window=window, k=k, y=y)  # prediction vs target

In [25]:
window = 8
# One sample per sliding window: (u[s:s+window], y[s:s+window]) as float32 arrays.
train_data = [
    (
        np.array(u[start:start + window], dtype=np.float32),
        np.array(y[start:start + window], dtype=np.float32),
    )
    for start in range(len(k) - window + 1)
]
# batch_size=1, no shuffling: NARX carries delay-line state between consecutive windows.
train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False)

In [30]:
torch.manual_seed(0)  # NARX weights come from torch.randn; fix the seed so re-runs reproduce
model = NARX(in_features=window, hid_features=10, out_features=window, in_delay=2, out_delay=3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
hist = fit(model, optimizer, nn.MSELoss(), epochs=800, data=train_loader)

Epoch: 800. Loss: 0.00038858: 100%|##########| 800/800 [08:05<00:00,  1.65it/s]


In [31]:
plot_history(h=hist)  # loss curve of the run above

In [32]:
plot_result(model=model, data=train_loader, window=window, k=k, y=y)  # prediction vs target