<center><img src="../img/img_0.PNG"  width="1000" height="240"/></center>


This notebook demonstrate an apllication of SimMTM, a simple self-supervised learning framework for time series modeling. Self-supervised learning is a learning paradigm that allows model to learn a good representation from the input data itself. The learned representation will be beneficial to some downstream tasks such as forecasting, classification and outlier detection. 

Self-supervised learning has a lof of success and achieves state-of-the-art performance in some domains, especially in the image domain. In this demo, we will show a self-supervised learning method, SimMTM, in the time-series domain. SimMTM adopts both masked modeling and contrastive modeling to learn a good representation of the input data. By using the learned representation and finetuning it, we achieve a significant improvement compared to the model without self-supervised learning. 

In [1]:
import os
try:
    os.chdir('src')
except:
    pass
print(os.getcwd())

/home/shamvinc/ssl_time_series/mvts_transformer/src


In [2]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from tqdm import tqdm
import copy

import math

from datasets.datasplit import split_dataset
from datasets.data import data_factory, Normalizer, TSRegressionArchive, CSVRegressionArchive
from datasets.datasplit import split_dataset
from datasets.dataset import collate_superv
from models.ts_transformer import model_factory
from models.loss import get_loss_module, contrastive_loss
from optimizers import get_optimizer

from options import Options
from running import setup


# Masked Modeling

Self-supervision via a ‘pretext task’ on input data combined with finetuning on labeled data is widely used for improving model performance in language and computer
vision. One of the popular self-supervision tasks on language data is masked modeling. Masking modeling is to mask some of the input entries randomly and predict those masked entries by using unmasked entries. By masked modeling, the model can learn the relationship through different features and different timesteps. 

<img src="../img/img_1.PNG"  width="900" height="240"/>

<img src="../img/img_2.PNG"  width="900" height="240"/>

# Masking Choice
### Random Masking

Random Masking is not a good choice to learn a good representation because the model can simply learn to take the average from the neighbour values. 



<img src="../img/img_3.PNG" width="600"/>

### Geometric Masking

Instead, we choose to use the geometric masking method, which is to mask a sequence of the input data randomly. The length of the sequence is followed by a geometric distribution. In this case, the model requires to recover a masked sequence from other unmasked input data. We suggest the expected length of a masked sequence is a half of the whole time series sequence.

In [4]:
def geom_noise_mask_single(L, lm, masking_ratio):
    """
    Randomly create a boolean mask of length `L`, consisting of subsequences of average length lm, masking with 0s a `masking_ratio`
    proportion of the sequence L. The length of masking subsequences and intervals follow a geometric distribution.
    Args:
        L: length of mask and sequence to be masked
        lm: average length of masking subsequences (streaks of 0s)
        masking_ratio: proportion of L to be masked

    Returns:
        (L,) boolean numpy array intended to mask ('drop') with 0s a sequence of length L
    """
    keep_mask = np.ones(L, dtype=bool)
    p_m = 1 / lm  # probability of each masking sequence stopping. parameter of geometric distribution.
    p_u = p_m * masking_ratio / (1 - masking_ratio)  # probability of each unmasked sequence stopping. parameter of geometric distribution.
    p = [p_m, p_u]

    # Start in state 0 with masking_ratio probability
    state = int(np.random.rand() > masking_ratio)  # state 0 means masking, 1 means not masking
    for i in range(L):
        keep_mask[i] = state  # here it happens that state and masking value corresponding to state are identical
        if np.random.rand() < p[state]:
            state = 1 - state

    return keep_mask

# SimMTM ultilizes both contrastive learning and mask modeling to learn the data representation.
## 1 - Contrastive Learning

when we mask the input time series data, we create many masked views of the input data. We expect that the distance between two views of the same time series sequence is minimized while maximizing the distance between two different sequences.

<img src="../img/img_5.png"/>

## The contrastive loss is the following: (Eq. 8 in the paper)

<center><img src="../img/img_6.PNG"/><center/>

In [5]:
def demo_contrastive_loss(s, batch_size, tau=0.05):
    s = s.squeeze(-1) 

    B = s.shape[0]
    v = s.reshape(B, -1)

    norm_v = torch.norm(v, p=2, dim=-1).unsqueeze(-1)
    v = v/norm_v
    u = torch.transpose(v, 0, 1)

    R = torch.matmul(v,u)

 
    R = torch.exp(R/tau) # (batch + mask size) x (batch + mask size)
    
    # number of masks
    M = B//batch_size
    mask = torch.eye(batch_size, device=R.device).repeat_interleave(M,dim=0).repeat_interleave(M,dim=1)

    denom = R * (torch.ones_like(R) - torch.eye(R.shape[0], device=R.device))

    denom = R.sum(-1).unsqueeze(-1)

    loss = torch.log(R/denom)
    

    loss = (loss * (mask - torch.eye(R.shape[0], device=R.device))).sum(1)/(M-1) # except no masked unit
    loss = loss.mean(0)
    
    return -loss


## 2 - Masked Modeling

SimMTM proposes to recover a time serie by the weighted sum of multiple masked points, which eases the reconstruction task by assembling ruined but complementary temporal variations.

<img src="../img/img_4.png"/>

In [6]:
from models.ts_transformer import LearnablePositionalEncoding, TransformerBatchNormEncoderLayer

class DemoSimMTMTransformerEncoder(nn.Module):
    
    def __init__(self, max_len, feat_dim, out_len, out_dim, d_model=16, n_heads=4, num_layers=2, dim_feedforward=32, dropout=0.2, temporal_unit=3):
        super(DemoSimMTMTransformerEncoder, self).__init__()

        self.max_len = max_len
        self.d_model = d_model
        self.n_heads = n_heads
        
        self.tau = 0.05
        self.mask_length = max_len//2
        self.mask_rate = 0.5

        self.project_inp = nn.Linear(feat_dim, d_model)
        self.projector_layer = nn.Linear(max_len, 1)
        self.pos_enc1 = LearnablePositionalEncoding(d_model, dropout=dropout, max_len=max_len)
        self.pos_enc2 = LearnablePositionalEncoding(d_model, dropout=dropout, max_len=out_len)
        
        self.act = F.gelu 

        # encoder_layer = nn.TransformerEncoderLayer(d_model, self.n_heads, dim_feedforward, dropout, activation='gelu')
        encoder_layer = TransformerBatchNormEncoderLayer(d_model, self.n_heads, dim_feedforward, dropout, activation='gelu')

        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        self.output_layer = nn.Linear(d_model, feat_dim)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout2d(dropout)

        # self.predict_layer1 = nn.Conv1d(d_model, 512, 5, stride=1)
        self.predict_layer1 = nn.Linear(max_len, out_len)
        self.predict_layer2 = nn.Linear(d_model, out_dim)
        # self.bn = nn.BatchNorm1d(d_model)

        self.feat_dim = feat_dim

        self.temporal_unit = temporal_unit

        self.w1 = torch.nn.parameter.Parameter(data=torch.ones(1), requires_grad=True)
        self.w2 = torch.nn.parameter.Parameter(data=torch.ones(1), requires_grad=True)
        
        
    def forward(self, X):
        """
        Reconstruct the input and create the projected output of X
        
        Args:
            X: (batch_size, seq_length, feat_dim) torch tensor of original input

        Returns:
            output: (batch_size, seq_length, feat_dim)
            s: (batch_size, d_model, 1)
        """

        
        _x = X
         
        # Create masked views of the input X
        for i in range(self.temporal_unit):
            mask = geom_noise_mask_single(X.shape[0] * X.shape[1] * X.shape[2], self.mask_length, self.mask_rate)
            mask = mask.reshape(X.shape[0], X.shape[1], X.shape[2])
            mask = torch.from_numpy(mask).to(X.device)
            x_masked = mask * X
            _x = torch.cat([_x, x_masked], axis=-1) # [batch_size, seq_length, feat_dim * temporal_unit]
    
        
        _x = _x.reshape(X.shape[0] * (self.temporal_unit + 1), X.shape[1], X.shape[2])
  

        inp = _x.permute(1, 0, 2)
        inp = self.project_inp(inp) * np.sqrt(self.d_model)  # [seq_length, batch_size, d_model] project input vectors to d_model dimensional space
        inp = self.pos_enc1(inp)  # add positional encoding

        
        output = self.transformer_encoder(inp)  # (seq_length, batch_size, d_model)
        output = self.act(output)  # the output transformer encoder/decoder embeddings don't include non-linearity
        output = output.permute(1, 0, 2)  # (batch_size, seq_length, d_model)
        output = self.dropout1(output)

        z_hat, _s = self.project(output, self.tau)
        # Most probably defining a Linear(d_model,feat_dim) vectorizes the operation over (seq_length, batch_size).
        output = self.output_layer(z_hat)  # (batch_size, seq_length, feat_dim)

        return output, _s

    
    def project(self, z, tau):
        """
        Output a weighted average of z
        
        Args:
            X: (batch_size, seq_length, feat_dim) torch tensor of original input

        Returns:
            z_hat: (batch_size, seq_length, d_model)
            s: (batch_size, d_model, 1)
        """
        _z = z.transpose(1, 2) # [batch_size, d_model, seq_length]
        _s = s = self.projector_layer(_z) # [batch_size, d_model, 1]
        
        if self.training:
            mask = torch.ones(1, self.d_model, 1).to(z.device)
            mask = self.dropout3(mask)
            s = s * mask 
            s = s + torch.randn(s.shape).to(z.device) * 1e-2
        
        
        s = s.squeeze(-1) 
        B = s.shape[0]
        v = s.reshape(B, -1)

        norm_v = torch.norm(v, p=2, dim=-1).unsqueeze(-1)
        v = v/norm_v
        u = torch.transpose(v, 0, 1)
        
        R = torch.matmul(v,u)
     
  
        R = torch.exp(R/tau) # (batch + mask size) x (batch + mask size)
        R = R * (torch.ones_like(R) - torch.eye(R.shape[0], device=R.device)) # zero out the weight of no masked component
        R = R/R.sum(-1).unsqueeze(-1)
        M = self.temporal_unit + 1
        R = R[::M] # extract every no mask unit # (batch size) x (batch + mask size)

        z_hat = (R.unsqueeze(-1).unsqueeze(-1) * z.unsqueeze(0)).sum(1) 
        return z_hat, _s


    def predict(self, X):
        """
        Predict an output given X
        
        Args:
            z: (batch_size, seq_length, d_model) torch tensor of representations of input
            tau: temperture of similarity matrix

        Returns:
            output: (batch_size, out_seq_len, out_dim)
        """
        
        # permute because pytorch convention for transformers is [seq_length, batch_size, feat_dim]. padding_masks [batch_size, feat_dim]
        inp = X.permute(1, 0, 2)
        inp = self.project_inp(inp) * np.sqrt(self.d_model)  # [seq_length, batch_size, d_model] project input vectors to d_model dimensional space
        inp = self.pos_enc1(inp)  # add positional encoding
        # NOTE: logic for padding masks is reversed to comply with definition in MultiHeadAttention, TransformerEncoderLayer

        output = self.transformer_encoder(inp)
        output = output.permute(1, 0, 2)  # (batch_size, seq_length, d_model)
        # output = self.dropout1(output)
       
        output = output.transpose(1, 2) # (batch_size, d_model, seq_length)
        output = self.predict_layer1(output)
        # output = self.act(output)
        
        output = output.transpose(1, 2) # (batch_size, seq_length, d_model)
        output = output.permute(1, 0, 2)
        output = self.pos_enc2(output)
        output = output.permute(1, 0, 2)
        output = self.dropout2(output)
        output = self.predict_layer2(output) 
        return output



# Data Loading and Preparation

In this demo, we use a benchmask time series dataset called ETT, which contains the time series of oil temperature and power load collected by electricity
transformers from July 2016 to July 2018. ETT is a group of four subsets with different recorded frequencies: ETTh1/ETTh2 are recorded every hour, and ETTm1/ETTm2 are recorded every 15 minutes.

In [7]:
args = Options().parse()  
args.data_dir = '../datasets/ETTh1'
args.task = 'regression'
args.output_dir = '../experiments'
config = setup(args)
from datasets.data import CSVRegressionArchive
data = CSVRegressionArchive(config['data_dir'], pattern='TRAIN', config=config)
_data = data

# Standard Normalization
normalizer = Normalizer(config['normalization'])
data.feature_df = normalizer.normalize(data.feature_df)
data.labels_df = data.feature_df

train_slice = slice(None, 12*30*24)
val_slice = slice(12*30*24, 16*30*24)
test_slice = slice(16*30*24, 20*30*24)
                

2023-08-24 07:27:01,815 | INFO : Stored configuration file in '../experiments/_2023-08-24_07-27-00_hhL'


In [8]:
data.feature_df

Unnamed: 0,HUFL,HULL,MUFL,MULL,LUFL,LULL,OT
0,-0.219043,-0.114203,-0.395671,-0.231896,0.976327,0.805715,2.008455
1,-0.238003,-0.081398,-0.411344,-0.251793,0.923944,0.857420,1.688155
2,-0.313840,-0.245425,-0.442544,-0.291035,0.610506,0.602230,1.688155
3,-0.323320,-0.147009,-0.442544,-0.271138,0.636268,0.703972,1.367970
4,-0.285401,-0.147009,-0.411344,-0.231896,0.688651,0.703972,1.006581
...,...,...,...,...,...,...,...
17415,-1.280344,0.640322,-1.452361,0.691116,0.348592,1.110943,-0.282559
17416,-1.820544,1.001183,-1.967523,0.769600,0.400975,1.364466,-0.266218
17417,-0.645488,0.771544,-0.749561,0.671772,0.558123,1.110943,-0.356448
17418,0.264279,0.771544,0.171637,0.671772,0.505741,0.959163,-0.413995


In [9]:
max_len = 72
out_size = 24
out_len = 7
# config['data_window_len'] = max_len
# config['task'] = 'simmtm'
# config['normalization_layer'] = 'BatchNorm'
# config['out_len'] = 24
# config['out_dim'] = 7
# config['d_model'] = 16
# config['dim_feedforward'] = 128
# config['num_heads'] = 4
# config['num_layers'] = 1
# from models.ts_transformer import model_factory
# model = model_factory(config, data)
model = DemoSimMTMTransformerEncoder(max_len=max_len, feat_dim=data.feature_df.shape[1], out_len=out_size, out_dim=out_len, 
                                     d_model=4, n_heads=4, num_layers=2, dim_feedforward=8)

device = "cuda"
model.to(device)
model.tau = 0.05
model.mask_length = max_len//2
model.mask_ratio = 0.5
model.temporal_unit = 3

In [10]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [11]:
from torch.utils.data import DataLoader
batch_size = 64


train_indices = data.feature_df[train_slice].index.values[:-max_len-out_size]
val_indices = data.feature_df[val_slice].index.values[:-max_len-out_size]
test_indices = data.feature_df[test_slice].index.values[:-max_len-out_size]

train_dataloader = DataLoader(train_indices, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_indices, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_indices, batch_size=batch_size, shuffle=True)



# Self-Supervised Learning Training Loop

In [12]:
i = 0
max_epoch = 10
best_loss = 1e10
best_epoch = 0
device = "cuda"
loss_fn = nn.MSELoss()
best_model = copy.deepcopy(model)

while i < max_epoch:
    train_loss = { "loss": [], "loss_mse": [], "loss_con": []}
    progress_bar = tqdm(train_dataloader)
    
    for IDs in progress_bar:
        model.train()
        X = list(map(lambda idx: np.expand_dims(data.feature_df.loc[idx:idx+max_len-1].to_numpy(), 0), IDs))
        X = np.concatenate(X, axis=0)
        X = torch.tensor(X).to(device)
        X = X.float()
        X = X.reshape(X.shape[0], max_len, -1)
        # X = X[:, :, -1:]
        
        pred, s = model(X)  # (batch_size, padded_length, feat_dim)
        
        loss_mse = loss_fn(pred, X) 

        loss_con = demo_contrastive_loss(s, X.shape[0])

        loss = 1/(model.w1.pow(2)) * loss_mse + 1/(model.w2.pow(2)) * loss_con + torch.log(model.w1) + torch.log(model.w2)
  


        optimizer.zero_grad()
        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=4.0)
        optimizer.step()
        # import ipdb; ipdb.set_trace()
        progress_bar.set_description("Epoch {0} - Training loss: {1:.2f} - MSE loss: {2:.2f} - Contrastive loss: {3:.2f}".format(i, 
                loss.cpu().detach().numpy().item(), loss_mse.cpu().detach().numpy().item(), loss_con.cpu().detach().numpy().item())) 
        train_loss["loss"].append(loss)
        train_loss["loss_mse"].append(loss_mse)
        train_loss["loss_con"].append(loss_con)
    
            
    with torch.no_grad():
        val_loss = { "loss": [], "loss_mse": [], "loss_con": []}
        for IDs in val_dataloader:
            model.eval()
            X = list(map(lambda idx: np.expand_dims(data.feature_df.loc[idx:idx+max_len-1].to_numpy(), 0), IDs))
            X = np.concatenate(X, axis=0)
            X = torch.tensor(X).to(device)
            X = X.float()
            X = X.reshape(X.shape[0], max_len, -1)
            # X = X[:, :, -1:]


            pred, s = model(X)  # (batch_size, padded_length, feat_dim)
        
            loss_mse = loss_fn(pred, X) 

            loss_con = demo_contrastive_loss(s, X.shape[0])

            loss = 1/(model.w1.pow(2)) * loss_mse + 1/(model.w2.pow(2)) * loss_con + torch.log(model.w1) + torch.log(model.w2)

            val_loss["loss"].append(loss)
            val_loss["loss_mse"].append(loss_mse)
            val_loss["loss_con"].append(loss_con)

        train_loss["loss"] = torch.tensor(train_loss["loss"]).mean()
        train_loss["loss_mse"] = torch.tensor(train_loss["loss_mse"]).mean()
        train_loss["loss_con"] = torch.tensor(train_loss["loss_con"]).mean()
        val_loss["loss"] = torch.tensor(val_loss["loss"]).mean()
        val_loss["loss_mse"] = torch.tensor(val_loss["loss_mse"]).mean()
        val_loss["loss_con"] = torch.tensor(val_loss["loss_con"]).mean()

        if val_loss["loss"] < best_loss:
            best_loss = val_loss["loss"]
            best_model = copy.deepcopy(model)
            best_epoch = i
    
        progress_bar.write("Epoch {0} - Training loss: {1:.2f} {2:.2f} {3:.2f} - Validation loss: {4:.2f} {5:.2f} {6:.2f}".format(i, 
            train_loss["loss"].cpu().detach().numpy().item(), train_loss["loss_mse"].cpu().detach().numpy().item(), train_loss["loss_con"].cpu().detach().numpy().item(),
            val_loss["loss"].cpu().detach().numpy().item(), val_loss["loss_mse"].cpu().detach().numpy().item(), val_loss["loss_con"].cpu().detach().numpy().item()))
    i += 1
    
    
tqdm.write("Best Epoch {} - Best Validation loss: {}".format(best_epoch, best_loss))

Epoch 0 - Training loss: 3.76 - MSE loss: 0.91 - Contrastive loss: 3.72: 100%|██████████████████████████████████████████████████████████████████████| 134/134 [00:10<00:00, 12.52it/s]
Epoch 1 - Training loss: 4.50 - MSE loss: 1.09 - Contrastive loss: 4.53:   1%|█                                                                       | 2/134 [00:00<00:10, 13.02it/s]

Epoch 0 - Training loss: 6.40 1.07 5.97 - Validation loss: 4.97 1.37 4.87


Epoch 1 - Training loss: 3.29 - MSE loss: 1.15 - Contrastive loss: 3.45: 100%|██████████████████████████████████████████████████████████████████████| 134/134 [00:10<00:00, 12.98it/s]
Epoch 2 - Training loss: 3.53 - MSE loss: 0.95 - Contrastive loss: 4.07:   1%|█                                                                       | 2/134 [00:00<00:10, 12.94it/s]

Epoch 1 - Training loss: 3.85 0.94 4.20 - Validation loss: 4.27 1.30 4.92


Epoch 2 - Training loss: 2.78 - MSE loss: 0.71 - Contrastive loss: 3.43: 100%|██████████████████████████████████████████████████████████████████████| 134/134 [00:10<00:00, 12.95it/s]
Epoch 3 - Training loss: 3.19 - MSE loss: 0.92 - Contrastive loss: 3.99:   1%|█                                                                       | 2/134 [00:00<00:10, 12.78it/s]

Epoch 2 - Training loss: 3.32 0.85 4.07 - Validation loss: 3.87 1.27 4.88


Epoch 3 - Training loss: 2.66 - MSE loss: 0.71 - Contrastive loss: 3.49: 100%|██████████████████████████████████████████████████████████████████████| 134/134 [00:10<00:00, 12.91it/s]
Epoch 4 - Training loss: 2.89 - MSE loss: 0.74 - Contrastive loss: 3.95:   1%|█                                                                       | 2/134 [00:00<00:10, 13.14it/s]

Epoch 3 - Training loss: 3.03 0.77 4.02 - Validation loss: 3.66 1.25 4.93


Epoch 4 - Training loss: 2.32 - MSE loss: 0.52 - Contrastive loss: 3.24: 100%|██████████████████████████████████████████████████████████████████████| 134/134 [00:10<00:00, 12.92it/s]
Epoch 5 - Training loss: 2.71 - MSE loss: 0.61 - Contrastive loss: 4.04:   1%|█                                                                       | 2/134 [00:00<00:10, 13.03it/s]

Epoch 4 - Training loss: 2.82 0.70 4.00 - Validation loss: 3.55 1.22 5.06


Epoch 5 - Training loss: 2.39 - MSE loss: 0.59 - Contrastive loss: 3.50: 100%|██████████████████████████████████████████████████████████████████████| 134/134 [00:10<00:00, 12.94it/s]
Epoch 6 - Training loss: 2.57 - MSE loss: 0.53 - Contrastive loss: 4.08:   1%|█                                                                       | 2/134 [00:00<00:10, 13.08it/s]

Epoch 5 - Training loss: 2.65 0.64 3.97 - Validation loss: 3.54 1.21 5.31


Epoch 6 - Training loss: 2.24 - MSE loss: 0.60 - Contrastive loss: 3.22: 100%|██████████████████████████████████████████████████████████████████████| 134/134 [00:10<00:00, 12.89it/s]
Epoch 7 - Training loss: 2.55 - MSE loss: 0.64 - Contrastive loss: 4.00:   1%|█                                                                       | 2/134 [00:00<00:10, 12.43it/s]

Epoch 6 - Training loss: 2.52 0.59 3.93 - Validation loss: 3.51 1.20 5.44


Epoch 7 - Training loss: 2.30 - MSE loss: 0.62 - Contrastive loss: 3.47: 100%|██████████████████████████████████████████████████████████████████████| 134/134 [00:10<00:00, 12.87it/s]
Epoch 8 - Training loss: 2.41 - MSE loss: 0.55 - Contrastive loss: 3.98:   1%|█                                                                       | 2/134 [00:00<00:10, 12.92it/s]

Epoch 7 - Training loss: 2.43 0.57 3.91 - Validation loss: 3.47 1.18 5.55


Epoch 8 - Training loss: 2.26 - MSE loss: 0.63 - Contrastive loss: 3.43: 100%|██████████████████████████████████████████████████████████████████████| 134/134 [00:10<00:00, 12.74it/s]
Epoch 9 - Training loss: 2.30 - MSE loss: 0.50 - Contrastive loss: 3.92:   1%|█                                                                       | 2/134 [00:00<00:10, 12.80it/s]

Epoch 8 - Training loss: 2.37 0.57 3.89 - Validation loss: 3.37 1.17 5.50


Epoch 9 - Training loss: 2.12 - MSE loss: 0.57 - Contrastive loss: 3.24: 100%|██████████████████████████████████████████████████████████████████████| 134/134 [00:10<00:00, 12.61it/s]


Epoch 9 - Training loss: 2.31 0.56 3.86 - Validation loss: 3.35 1.16 5.61
Best Epoch 9 - Best Validation loss: 3.348235845565796


# Finetune Training Loop

Our downstream task is to forecast the next 24 hours time series data given the past 72 hours time series data. To do this, we append a linear layer on our learned representation, and finetune it for our downstream task.

In [13]:
finetune_model = copy.deepcopy(best_model)
optimizer = torch.optim.AdamW(finetune_model.parameters(), lr=1e-3)

In [14]:
i = 0
max_epoch = 10
best_loss = 1e10
best_finetune_model = copy.deepcopy(best_model)
best_epoch = 0
device = "cuda"
finetune_model.to(device)
while i < max_epoch:
    train_loss = []
    progress_bar = tqdm(train_dataloader)
    
    for IDs in progress_bar:
        finetune_model.train()
        
        X = list(map(lambda idx: np.expand_dims(data.feature_df.loc[idx:idx+max_len-1].to_numpy(), 0), IDs))
        X = np.concatenate(X, axis=0)
        X = torch.tensor(X).to(device)
        X = X.float()
        X = X.reshape(X.shape[0], max_len, -1)
        # X = X[:, :, -1:]
        
        targets =  list(map(lambda idx: np.expand_dims(data.labels_df.loc[idx+max_len:idx+max_len+out_size-1].to_numpy(), 0), IDs))
        targets = np.concatenate(targets, axis=0)
        # targets = torch.tensor(targets[:,:,-1]).to(device)
        targets = torch.tensor(targets).to(device)
        targets = targets.float()
        targets = targets.reshape(X.shape[0], out_size, -1)
        # targets = targets[:, :, -1:]

        pred = finetune_model.predict(X.float())
        pred = pred.reshape(X.shape[0], out_size, -1)
        loss = loss_fn(pred, targets)


        optimizer.zero_grad()
        loss.backward()

        nn.utils.clip_grad_norm_(finetune_model.parameters(), max_norm=4.0)
        optimizer.step()

        progress_bar.set_description("Epoch {} - Training loss: {:.2f}".format(i, loss)) 
        train_loss.append(loss)
    
    with torch.no_grad():
        val_loss = []
        for IDs in val_dataloader:
            finetune_model.eval()
            X = list(map(lambda idx: np.expand_dims(data.feature_df.loc[idx:idx+max_len-1].to_numpy(), 0), IDs))
            X = np.concatenate(X, axis=0)
            X = torch.tensor(X).to(device)
            X = X.float()
            X = X.reshape(X.shape[0], max_len, -1)
            # X = X[:, :, -1:]
            
            targets =  list(map(lambda idx: np.expand_dims(data.labels_df.loc[idx+max_len:idx+max_len+out_size-1].to_numpy(), 0), IDs))
            targets = np.concatenate(targets, axis=0)
            # targets = torch.tensor(targets[:,:,-1]).to(device)
            targets = torch.tensor(targets).to(device)
            targets = targets.float()
            targets = targets.reshape(X.shape[0], out_size, -1)
            # targets = targets[:, :, -1:]

            pred = finetune_model.predict(X.float())
            pred = pred.reshape(X.shape[0], out_size, -1)
            
            loss = loss_fn(pred, targets)
            val_loss.append(loss)

        train_loss = torch.tensor(train_loss).mean()
        val_loss = torch.tensor(val_loss).mean()

        if val_loss < best_loss:
            best_loss = val_loss
            best_finetune_model = copy.deepcopy(finetune_model)
            best_epoch = i
    
    progress_bar.write("Epoch {} - Training loss: {:.2f} - Validation loss: {:.2f}".format(i, train_loss, val_loss))
    i += 1
    
    
tqdm.write("Best Epoch {} - Best Validation loss: {}".format(best_epoch, best_loss))

Epoch 0 - Training loss: 0.73: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:04<00:00, 29.01it/s]
Epoch 1 - Training loss: 0.69:   2%|██▌                                                                                                               | 3/134 [00:00<00:04, 28.85it/s]

Epoch 0 - Training loss: 0.82 - Validation loss: 1.11


Epoch 1 - Training loss: 0.62: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:04<00:00, 28.76it/s]
Epoch 2 - Training loss: 0.57:   2%|██▌                                                                                                               | 3/134 [00:00<00:04, 28.90it/s]

Epoch 1 - Training loss: 0.64 - Validation loss: 1.01


Epoch 2 - Training loss: 0.56: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:04<00:00, 29.05it/s]
Epoch 3 - Training loss: 0.60:   2%|██▌                                                                                                               | 3/134 [00:00<00:04, 29.50it/s]

Epoch 2 - Training loss: 0.59 - Validation loss: 0.92


Epoch 3 - Training loss: 0.65: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:04<00:00, 29.62it/s]
Epoch 4 - Training loss: 0.54:   2%|██▌                                                                                                               | 3/134 [00:00<00:04, 29.34it/s]

Epoch 3 - Training loss: 0.56 - Validation loss: 0.84


Epoch 4 - Training loss: 0.57: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:04<00:00, 27.89it/s]
Epoch 5 - Training loss: 0.53:   2%|██▌                                                                                                               | 3/134 [00:00<00:04, 28.54it/s]

Epoch 4 - Training loss: 0.53 - Validation loss: 0.82


Epoch 5 - Training loss: 0.53: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:04<00:00, 28.90it/s]
Epoch 6 - Training loss: 0.51:   2%|██▌                                                                                                               | 3/134 [00:00<00:04, 28.98it/s]

Epoch 5 - Training loss: 0.52 - Validation loss: 0.78


Epoch 6 - Training loss: 0.49: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:04<00:00, 28.26it/s]
Epoch 7 - Training loss: 0.47:   2%|██▌                                                                                                               | 3/134 [00:00<00:05, 25.76it/s]

Epoch 6 - Training loss: 0.51 - Validation loss: 0.75


Epoch 7 - Training loss: 0.51: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:05<00:00, 25.65it/s]
Epoch 8 - Training loss: 0.49:   2%|██▌                                                                                                               | 3/134 [00:00<00:04, 29.59it/s]

Epoch 7 - Training loss: 0.50 - Validation loss: 0.74


Epoch 8 - Training loss: 0.55: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:04<00:00, 29.35it/s]
Epoch 9 - Training loss: 0.48:   2%|██▌                                                                                                               | 3/134 [00:00<00:04, 28.61it/s]

Epoch 8 - Training loss: 0.50 - Validation loss: 0.73


Epoch 9 - Training loss: 0.53: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:04<00:00, 27.32it/s]


Epoch 9 - Training loss: 0.50 - Validation loss: 0.72
Best Epoch 9 - Best Validation loss: 0.7233729362487793


In [15]:
test_loss = []
with torch.no_grad():
    for IDs in test_dataloader:
        best_finetune_model.eval()
        X = list(map(lambda idx: np.expand_dims(data.feature_df.loc[idx:idx+max_len-1].to_numpy(), 0), IDs))
        X = np.concatenate(X, axis=0)
        X = torch.tensor(X).to(device)
        X = X.float()
        X = X.reshape(X.shape[0], max_len, -1)
        # X = X[:, :, -1:]
        

        targets =  list(map(lambda idx: np.expand_dims(data.labels_df.loc[idx+max_len:idx+max_len+out_size-1].to_numpy(), 0), IDs))
        targets = np.concatenate(targets, axis=0)
        # targets = torch.tensor(targets[:,:,-1]).to(device)
        targets = torch.tensor(targets).to(device)
        targets = targets.float()
        targets = targets.reshape(X.shape[0], out_size, -1)
        # targets = targets[:, :, -1:]
        

        pred = best_finetune_model.predict(X.float())
        pred = pred.reshape(X.shape[0], out_size, -1)
        loss = loss_fn(pred, targets)


        test_loss.append(loss)


test_loss = torch.tensor(test_loss).mean()
print("Test MSE loss: {}".format(test_loss))
print("Test RMSE loss: {}".format(np.sqrt(test_loss)))

Test loss: 0.5565699934959412


Reference:
1. https://arxiv.org/abs/2302.00861
2. https://github.com/gzerveas/mvts_transformer