In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import torch

from data_loader import DataLoader

# Data Loading

In [None]:
# Project-local loader imported from data_loader.py (not torch.utils.data.DataLoader,
# which is used further down under its fully qualified name). Built with default arguments.
data_loader = DataLoader()

In [None]:
# MTS data loading: ask the loader which datasets it offers, then load both
# splits of the entry at index 1.
mts_datasets = data_loader.get_mts_datasets()
selected_mts = mts_datasets[1]

X_train, y_train = data_loader.load_mts_dataset(selected_mts, split="train")
X_test, y_test = data_loader.load_mts_dataset(selected_mts, split="test")

# Report the shapes of inputs and labels for each split
for split_label, inputs, labels in (
    ("Train", X_train, y_train),
    ("Test", X_test, y_test),
):
    print(split_label, inputs.shape, labels.shape)

In [None]:
# Text data loading: fetch inputs and labels from "data", then preview each one
X, y = data_loader.load_text_dataset("data")

for loaded in (X, y):
    # First five entries, followed by the overall shape
    print(loaded[:5])
    print(loaded.shape)

# Time Series

In [None]:
from encoder import CausalCNNEncoder

# ---- Model parameters ----

# CNN parameters
in_channels = 1     # always 1 (since we are using 1D convolutions)
channels = 20       # hidden channels within the CNN layers
depth = 3
reduced_size = 80   # output size of the convolutional layers
kernel_size = 3     # convolution kernel size

# Encoder parameters
out_channels = 160  # output dimensionality of the encoder

# Arguments are passed positionally, in the order the constructor is called with
# in this notebook; .double() is applied to the freshly built encoder.
causal_cnn = CausalCNNEncoder(
    in_channels,
    channels,
    depth,
    reduced_size,
    out_channels,
    kernel_size,
).double()

In [None]:
from triplet_loss import PNTripletLoss

# ---- Training parameters ----
# Mini-batch size, number of passes over the training data, and Adam step size
batch_size, epochs, lr = 16, 30, 0.001

# Objective and optimizer; the optimizer updates the encoder's parameters
loss_function = PNTripletLoss()
adam = torch.optim.Adam
optimizer = adam(causal_cnn.parameters(), lr=lr)

In [None]:
from utils import Dataset

# Wrap the training inputs in the project's Dataset and serve them as
# shuffled mini-batches of `batch_size` samples.
train_dataset = Dataset(X_train)
train_generator = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Encoder training
# Seed torch and NumPy right before training so batch shuffling is repeatable
# when this cell is re-run.
torch.manual_seed(0)
np.random.seed(0)

# One Python float per epoch: the mean triplet loss over that epoch's batches.
history = []
for epoch in range(1, epochs + 1):
    batch_losses = []
    for batch in train_generator:
        optimizer.zero_grad()
        # No model call here, that is done in the loss function directly
        loss = loss_function(batch, causal_cnn)
        loss.backward()
        optimizer.step()
        # .item() copies the scalar out of the tensor, so the history holds plain
        # floats instead of graph-attached tensors that still reference grad_fn.
        batch_losses.append(loss.item())

    if not batch_losses:
        # Without this guard an empty loader would surface as a NameError on `loss`.
        raise ValueError("train_generator yielded no batches; nothing to train on")

    # Average over the epoch rather than reporting whichever (shuffled) batch came last.
    epoch_loss = sum(batch_losses) / len(batch_losses)
    print("Epoch", epoch, epoch_loss)
    history.append(epoch_loss)

# Kept as a tensor (as before) for any later cell that uses `history`.
history = torch.tensor(history, dtype=torch.float64)

# Plot against 1..epochs so the x-axis matches the epoch numbers printed above.
epoch_axis = np.arange(1, epochs + 1)
plt.plot(epoch_axis, history.numpy())
plt.xticks(epoch_axis)
plt.xlabel("Epoch")
plt.ylabel("Triplet Loss");

# Text