In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
# Synthetic data generation
def generate_synthetic_data(num_points=20, seed=None):
    """Sample a noisy sine wave at irregularly spaced time points.

    Args:
        num_points: Number of observations to draw.
        seed: Optional integer seed. When given, a private
            ``np.random.RandomState(seed)`` is used, so the draw is
            reproducible and NumPy's global random state is left untouched.
            When ``None`` (the default) the global ``np.random`` state is
            used, exactly as before this parameter existed.

    Returns:
        Tuple ``(t, x)`` of 1-D arrays of length ``num_points``. ``t`` holds
        time points drawn uniformly from the unit interval and sorted in
        ascending order; ``x`` is ``sin(2*pi*t)`` plus Gaussian noise with
        standard deviation 0.1.
    """
    # The module-level functions and RandomState expose the same uniform/normal
    # API, so either can serve as the sampler.
    rng = np.random if seed is None else np.random.RandomState(seed)
    t = np.sort(rng.uniform(0, 1, num_points))  # Irregular time points
    x = np.sin(2 * np.pi * t) + rng.normal(0, 0.1, num_points)
    return t, x

# Time embedding
class TimeEmbedding(nn.Module):
    """Learned cosine embedding of scalar time stamps.

    Every time value is passed through a ``Linear(1, dim)`` layer and the
    cosine of the result is returned, so a time tensor of shape ``(...)``
    becomes an embedding of shape ``(..., dim)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # One learnable weight and bias per embedding channel.
        self.w = nn.Linear(1, dim)

    def forward(self, t):
        """Return ``cos(w(t))`` with a trailing embedding axis of size ``dim``."""
        # Linear expects a feature axis, so append one of size 1 first.
        projected = self.w(t.unsqueeze(-1))
        return torch.cos(projected)

In [None]:
# Multi-Time Attention Layer
class MultiTimeAttention(nn.Module):
    """Multi-head self-attention whose scores are augmented by a time embedding.

    Queries, keys and values are linear projections of the same input ``x``.
    The attention logits are the sum of the usual query-key product and a
    query-time-embedding product, both scaled by ``sqrt(head_dim)``.

    Args:
        input_dim: Size of the last axis of ``x``.
        embed_dim: Total width of the projections and of the output.
        num_heads: Number of attention heads; must be positive and divide
            ``embed_dim`` evenly.

    Raises:
        ValueError: If ``embed_dim`` cannot be split evenly across
            ``num_heads``. Without this check the mismatch would only surface
            later as a reshape error inside ``forward``.
    """

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads (got {num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_linear = nn.Linear(input_dim, embed_dim)
        self.k_linear = nn.Linear(input_dim, embed_dim)
        self.v_linear = nn.Linear(input_dim, embed_dim)

        self.time_embedding = TimeEmbedding(embed_dim)

    def forward(self, x, t):
        """Attend over the sequence axis of ``x`` using time-aware scores.

        Args:
            x: Tensor of shape ``(B, L, input_dim)``.
            t: Time stamps holding one value per (batch, position) pair, so
                that their embedding can be viewed as ``(B, L, embed_dim)``.

        Returns:
            Tensor of shape ``(B, L, embed_dim)``.
        """
        B, L, _ = x.shape

        def split_heads(tensor):
            # (B, L, embed_dim) -> (B, num_heads, L, head_dim)
            return tensor.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_linear(x))
        k = split_heads(self.k_linear(x))
        v = split_heads(self.v_linear(x))
        te = split_heads(self.time_embedding(t))

        # q @ k^T + q @ te^T == q @ (k + te)^T, so a single matmul yields the
        # same logits while building only one (L, L) score matrix per head.
        scores = torch.matmul(q, (k + te).transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = torch.softmax(scores, dim=-1)

        out = torch.matmul(attn, v)
        # Merge the heads back: (B, num_heads, L, head_dim) -> (B, L, embed_dim)
        return out.transpose(1, 2).contiguous().view(B, L, self.embed_dim)

In [None]:
# Improved mTAN model
class ImprovedMTAN(nn.Module):
    """Stack of residual time-aware attention layers with a scalar read-out.

    The input series is projected to ``hidden_dim``, refined by ``num_layers``
    ``MultiTimeAttention`` layers (each wrapped in a residual connection) and
    mapped back to one value per time step by a final linear layer.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Width of the hidden representation.
        num_heads: Attention heads used by every attention layer.
        num_layers: Number of attention layers.
    """

    def __init__(self, input_dim, hidden_dim, num_heads, num_layers):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.attention_layers = nn.ModuleList([
            MultiTimeAttention(hidden_dim, hidden_dim, num_heads)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, t, x):
        """Map a series to one output value per time step.

        Args:
            t: Time stamps, forwarded unchanged to every attention layer.
            x: Observed values. A tensor without a feature axis, e.g. shape
                ``(B, L)``, is treated as univariate and gets a trailing axis
                of size 1. A tensor with three or more axes is assumed to
                already end in a feature axis of size ``input_dim``.

        Returns:
            Tensor with the trailing size-1 read-out axis removed, i.e.
            ``(B, L)`` for batched input.
        """
        # Only add the feature axis when it is missing; unconditionally
        # unsqueezing would make every input_dim other than 1 unusable.
        if x.dim() < 3:
            x = x.unsqueeze(-1)
        x = self.input_proj(x)
        for layer in self.attention_layers:
            x = layer(x, t) + x  # Residual connection
        return self.fc(x).squeeze(-1)

In [None]:
# Training function
def train_model(model, t, x, epochs=5000, lr=0.001):
    """Train ``model`` to reproduce the observed values at the observed times.

    The model is called as ``model(t, x)`` on a single batch and its output
    is compared with ``x`` itself using mean squared error, i.e. this is a
    reconstruction objective. Optimisation uses Adam; the learning rate is
    halved whenever the loss has not improved for 100 consecutive epochs.

    Args:
        model: Module called as ``model(t_tensor, x_tensor)`` whose output is
            comparable with ``x_tensor`` (shape ``(1, len(x))``).
        t: 1-D array-like of observation times.
        x: 1-D array-like of observed values, same length as ``t``.
        epochs: Number of optimisation steps.
        lr: Initial Adam learning rate.

    Returns:
        list[float]: The training loss recorded at every epoch (empty when
        ``epochs`` is 0).
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # The scheduler's `verbose` flag is deprecated and no longer accepted by
    # recent PyTorch releases, so learning-rate reductions are reported
    # manually inside the loop instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=100, factor=0.5)
    criterion = nn.MSELoss()  # built once instead of once per epoch
    t_tensor = torch.tensor(t).float().unsqueeze(0)
    x_tensor = torch.tensor(x).float().unsqueeze(0)

    history = []
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(t_tensor, x_tensor)
        loss = criterion(output, x_tensor)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        history.append(loss_value)

        # Step the scheduler on a plain float and report any LR reduction.
        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(loss_value)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after < lr_before:
            print(f"Epoch {epoch+1}: reducing learning rate to {lr_after:.4e}")

        if (epoch + 1) % 1000 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss_value:.4f}")

    return history

In [None]:
# Interpolation function
def interpolate(model, t, x, t_dense):
    """Evaluate ``model`` on a dense time grid.

    The model needs an input value at every query time, so the observations
    ``(t, x)`` are first linearly interpolated onto ``t_dense`` with
    ``np.interp``. The model is then run on that dense series with gradient
    tracking disabled. Note that the model therefore never sees the raw
    observations here, only their linear interpolation.

    Args:
        model: Callable invoked as ``model(t_tensor, x_tensor)`` with two
            float tensors of shape ``(1, len(t_dense))``.
        t: 1-D array-like of observation times; ``np.interp`` expects these
            in increasing order.
        x: 1-D array-like of observed values, same length as ``t``.
        t_dense: 1-D array-like of query times.

    Returns:
        numpy.ndarray: The model output with size-1 axes removed, always at
        least 1-D (one value per query time for a model returning shape
        ``(1, len(t_dense))``).
    """
    with torch.no_grad():
        # Build an input value for every time in t_dense.
        x_dense = np.interp(t_dense, t, x)

        # Use the interpolated x_dense rather than the original x values.
        t_dense_tensor = torch.tensor(t_dense).float().unsqueeze(0)
        x_dense_tensor = torch.tensor(x_dense).float().unsqueeze(0)

        # Feed the dense times and the interpolated values to the model.
        output = model(t_dense_tensor, x_dense_tensor)

        # A bare squeeze() collapses a single-point grid to a 0-d array;
        # atleast_1d keeps the result indexable and plottable in that case
        # and leaves every other shape as it was.
        return np.atleast_1d(output.squeeze().cpu().numpy())

In [None]:
# Main experiment
# Seed every RNG the experiment draws from (NumPy for the synthetic data,
# torch for weight initialisation) so that Restart & Run All reproduces the
# same data, the same trained model and the same figure.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# 30 noisy, irregularly spaced observations of one sine period.
t, x = generate_synthetic_data(num_points=30)

mtan_model = ImprovedMTAN(input_dim=1, hidden_dim=64, num_heads=4, num_layers=2)

print("Training mTAN model:")
train_model(mtan_model, t, x)

# Evaluate the trained model on a uniform grid of 1000 points over [0, 1].
t_dense = np.linspace(0, 1, 1000)
mtan_interp = interpolate(mtan_model, t, x, t_dense)

In [None]:
# Plotting
# Overlay the noise-free sine, the raw observations and the model output.
# Each entry is (x-values, y-values, matplotlib format string, legend label).
curves = (
    (t_dense, np.sin(2 * np.pi * t_dense), 'r-', 'Ground truth'),
    (t, x, 'kx', 'Observed data'),
    (t_dense, mtan_interp, 'b-', 'mTAN interpolation'),
)

plt.figure(figsize=(10, 6))
for times, values, fmt, curve_label in curves:
    plt.plot(times, values, fmt, label=curve_label)
plt.title("mTAN Interpolation on Irregular Time Series")
plt.xlabel("Time")
plt.ylabel("Value")
plt.legend()
plt.show()