In [None]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

In [None]:
from functools import partial

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

from datasets import Dataset
from transformers import TimeSeriesTransformerConfig, TimeSeriesTransformerForPrediction

from src.utils import transform_start_field
from src import ts_transformer as tsf
from src.inference.wrapper import TFWrapper
from src.ts_transformer import create_train_dataloader

from src.networks.classifier import RepresentationClassifier, train_classifier

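# Get Dataset

This notebook assumes a pandas DataFrame named `data` is already loaded. It must have a `feat_static_cat` column whose first entry in each row is the class label. The frame is split 80/20 into train/test with a fixed seed (`random_state=42`).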

In [None]:
# Assumes a pandas DataFrame named `data` has already been loaded into the kernel
# (it is not created anywhere in this notebook).
# 80/20 train/test split with a fixed seed; both halves get a fresh 0..n-1 index.
train_df, test_df = (
    part.reset_index(drop=True)
    for part in train_test_split(data, test_size=0.2, random_state=42)
)
print(train_df.shape)

# Load Model

In [None]:
# Load a pretrained model; training itself lives in train_test.ipynb.
# Replace the placeholder below with the directory/repo id of the saved weights.
WEIGHTS_PATH = "<PATH_TO_WEIGHTS>"

# Pandas frequency alias ('1H' = hourly). It must match the frequency the
# pretrained model was trained on — TODO confirm against train_test.ipynb.
# NOTE(review): pandas >= 2.2 deprecates 'H' in favour of 'h'.
freq = '1H'

transformer = TimeSeriesTransformerForPrediction.from_pretrained(WEIGHTS_PATH)
model = TFWrapper(transformer, freq)

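# Classifier

1. Forward-pass each batch of series through the pretrained transformer.
2. Extract its latent representations.
3. Pair each representation with its label (the first static categorical feature).
4. Jointly train the projection network and the classifier on these representations. Only the classifier's parameters are given to the optimizer, so the optimizer does not update the transformer's weights.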

In [None]:
# Carve a validation set (20%) out of the training split.
# This cell rebinds `train_df` to a subset of itself. Re-running it would split the
# already-reduced frame again and silently shrink the training set, so fail loudly
# unless `train_df` is still the full training half produced by the train/test split.
assert len(train_df) + len(test_df) == len(data), (
    "train_df has already been split into train/valid; re-run the train/test split cell first"
)

train_df, valid_df = train_test_split(train_df, test_size=0.2, random_state=42)
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)
print(train_df.shape, valid_df.shape)

In [None]:
def to_hf_dataset(frame):
    """Wrap a DataFrame as a Hugging Face `Dataset` with `transform_start_field` attached.

    The transform is registered via `set_transform`, so it is applied lazily when
    rows are accessed rather than materialised up front. It reads the module-level
    `freq` at call time.
    """
    hf_ds = Dataset.from_pandas(frame, preserve_index=False)
    hf_ds.set_transform(partial(transform_start_field, freq=freq))
    return hf_ds


train_data = to_hf_dataset(train_df)
valid_data = to_hf_dataset(valid_df)

In [None]:
# Both loaders share the model's config/frequency and the same batching setup.
# NOTE(review): validation also goes through `create_train_dataloader`, which
# presumably samples random training windows; if so, validation loss is computed on a
# random subset each epoch. Confirm this is intended.
loader_kwargs = dict(
    config=model.model_config,
    freq=model.freq,
    batch_size=32,
    num_batches_per_epoch=16,
)

train_dataloader = create_train_dataloader(data=train_data, **loader_kwargs)
valid_dataloader = create_train_dataloader(data=valid_data, **loader_kwargs)

In [None]:
num_epochs = 50

# Seed torch so the classifier's random weight initialisation is reproducible
# across runs (the data splits above already use random_state=42). Assumes the
# classifier draws its initial weights from torch's default RNG.
torch.manual_seed(42)

# We classify the first static categorical feature.
# nn.CrossEntropyLoss needs integer targets in [0, num_classes), so size the head
# from the largest label id over the *full* dataset. Counting distinct labels in the
# (reduced) training split alone undercounts when a class is absent from that split
# or when ids are non-contiguous, which would crash the loss with out-of-range targets.
label_ids = {int(cats[0]) for cats in data['feat_static_cat']}
num_classes = max(label_ids) + 1

classifier = RepresentationClassifier(
    encoder_hidden_size=model.model_config.d_model,
    # Hidden-layer sizes, e.g. [64, 32]; None presumably means no hidden layers — TODO confirm.
    attn_hidden_dims=None,
    classifier_hidden_dims=None,
    num_classes=num_classes,
    attn_activation=nn.Tanh,
    classifier_activation=nn.ReLU,
    attn_dropout=0,
    classifier_dropout=0,
)

# Only the classifier's parameters are optimised; the transformer is not passed in.
optimizer = optim.Adam(classifier.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [None]:
# Train the classifier on representations produced by the pretrained transformer.
# The optimizer was built from `classifier.parameters()` only; the transformer is
# presumably used purely as a feature extractor — confirm train_classifier keeps it frozen.
device = model.transformer.device

classifier, train_losses, valid_losses = train_classifier(
    model=model.transformer,
    classifier=classifier,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_dataloader,
    valid_loader=valid_dataloader,
    num_epochs=num_epochs,
    device=device,
)

In [None]:
# Training and validation loss curves returned by train_classifier
# (assumed to be one value per epoch, matching the x-axis label).
# The title now covers both curves; it previously said "Training Loss" only.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(train_losses, label='Classifier Training Loss')
ax.plot(valid_losses, label='Classifier Validation Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='Classifier Loss Over Epochs')
ax.legend()
plt.show()