<a href="https://colab.research.google.com/github/whoami-Lory271/thesis-project/blob/main/thesis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Symbol legend

* B: batch size 
* M: number of channels
* P: patch dimension
* N: number of patches
* L: lookback window


# Installations and imports


In [1]:
# Third-party dependencies for the Colab runtime; pytorch-lightning and einops are pinned, ipdb is not.
!pip install pytorch-lightning==2.0.1.post0 --quiet
!pip install einops==0.6.1 --quiet
!pip install ipdb --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m718.6/718.6 kB[0m [31m43.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m49.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m23.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m28.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m149.6/149.6 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.5/114.5 kB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m74.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import pandas as pd
import logging
# https://github.com/gotcha/ipdb
import ipdb
import copy
from google.colab import drive
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
# https://theaisummer.com/einsum-attention/
import einops
import math
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.callbacks import EarlyStopping

In [3]:
# Mount Google Drive: ROOT_FOLDER (datasets, checkpoints, logs) lives under /content/drive.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Constants

In [4]:
# logger
LOG_LEVEL = logging.DEBUG  # level applied to the 'APP' logger and to every per-module logger

# datasets
ELECTRICITY = "electricity"

# models
CoPST = "CoPST"

# paths (all located on the mounted Google Drive)
ROOT_FOLDER = "/content/drive/MyDrive/Università/Magistrale/Tesi/code"
MODEL_FOLDER = ROOT_FOLDER + "/models"
CHECKPOINT_FOLDER = ROOT_FOLDER + "/checkpoints"
LOGS_FOLDER = ROOT_FOLDER + "/logs"

# hyperparameters
BATCH_SIZE = 4
STRIDE = 4  # step between the starts of two consecutive patches
PATCH_LEN = 8  # P: number of time steps per patch
L = 32  # lookback window
PATCH_NUM = 8  # N: CoPSTEncoder asserts it equals (max(PATCH_LEN, L) - PATCH_LEN) // STRIDE + 2
M = 370 # TODO: generalize for all datasets
DICT_MOMENTUM_SIZE = 100  # the key queue holds BATCH_SIZE * M * DICT_MOMENTUM_SIZE entries

# training
DETERMINISTIC = True  # seed every RNG before building the model and make the Trainer deterministic
LOAD_MODEL = False  # when True the model is restored from a checkpoint instead of being created

# Logger

In [5]:
# create the notebook-level logger (train() uses it to report the best checkpoint path)
log = logging.getLogger('APP')
log.setLevel(LOG_LEVEL)

# Disabled console-handler setup, kept for reference.
# NOTE(review): the last disabled line refers to `logger`, while the object created above is `log`.

# # create console handler and set level to debug
# ch = logging.StreamHandler()
# ch.setLevel(logging.INFO)

# # create formatter
# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# # add formatter to ch
# ch.setFormatter(formatter)

# # add ch to logger
# logger.addHandler(ch)

In [6]:
# setup logger function
def setup_log(self, level):
    """
    Return the logger named after the class of `self`, configured at `level`.

    self  -> any object; only its class name is used
    level -> logging level applied to the returned logger
    """
    owner_name = type(self).__name__
    named_logger = logging.getLogger(owner_name)
    named_logger.setLevel(level)
    return named_logger

# Preprocessing

## Datasets

In [7]:
# folder of each dataset, keyed by dataset name
datasets_path = {
    ELECTRICITY: ROOT_FOLDER + "/datasets/electricity"
}

# raw file name of each dataset (appended to the dataset folder)
datasets_name = {
    ELECTRICITY: "/LD2011_2014.txt"    
}
# pickled (preprocessed) file name of each dataset (appended to the dataset folder)
datasets_processed_name = {
    ELECTRICITY: "/electricity.pkl"
}

### Electricity

**Preprocessing**

In [8]:
# df = pd.read_csv(datasets_path[ELECTRICITY] + datasets_name[ELECTRICITY], sep = ';')
# df.rename(columns={df.columns[0]: 'Date'},inplace=True)
# df.to_pickle(datasets_path[ELECTRICITY] + datasets_processed_name[ELECTRICITY])

In [9]:
# df = pd.read_pickle(datasets_path[ELECTRICITY] + datasets_processed_name[ELECTRICITY])

In [10]:
class ElectricityDataset(Dataset):
    """
    Non-overlapping windows over the electricity table; every item is a pair of
    independently augmented copies of the same window.

    data        -> table indexed with .iloc; column 0 (the date) is skipped, the rest are channels
    look_window -> number of rows per window (L)
    p           -> threshold compared with random.random(); an augmentation is skipped when the draw is above it
    multiplier  -> stored but not used by this class
    """

    def __init__(self, data, look_window, p = 0.5, multiplier = 10):
        super().__init__()
        self.data, self.look_window = data, look_window
        self.p, self.multiplier = p, multiplier
        # one random value drawn at construction time and shared by scale() and shift()
        self.epsilon = torch.empty(1).normal_(mean = 0, std = 0.5)

    def __len__(self):
        # only complete windows are counted
        n_rows = self.data.shape[0]
        return n_rows // self.look_window

    def __getitem__(self, idx):
        # map the window index to its first row
        first_row = idx * self.look_window
        # keep every channel, drop the date column
        window = self.data.iloc[first_row:first_row + self.look_window, 1:]
        # values are read as text with a decimal comma and converted to float32
        numeric = window.astype('str').replace(',','.', regex = True).astype('float32')
        ts = torch.from_numpy(numeric.values) # [L x M]
        first_view = self.transform(ts)
        second_view = self.transform(ts)
        return first_view, second_view

    def get_len(self):
        return len(self)

    # def get_channels(self):
    #     return self.data.iloc[0, 1:].astype(str).str.replace(',', '.').astype('float32').shape[0]

    def transform(self, x):
        # augmentations are chained in this order: scale, shift, jitter
        for augment in (self.scale, self.shift, self.jitter):
            x = augment(x)
        return x

    def jitter(self, x):
        # adds a freshly drawn scalar N(0, 0.5) to the whole window
        return x if random.random() > self.p else x + torch.empty(1).normal_(mean = 0, std = 0.5)

    def scale(self, x):
        # multiplies the whole window by the shared epsilon
        return x if random.random() > self.p else x * self.epsilon

    def shift(self, x):
        # adds the shared epsilon to the whole window
        return x if random.random() > self.p else x + self.epsilon

class ElectricityDataModule(pl.LightningDataModule):
    """
    Lightning data module for the pickled electricity table.

    path        -> path of the pickled table (loaded eagerly in __init__)
    batch_size  -> batch size of every dataloader
    look_window -> number of rows per sample (L)
    train_size  -> fraction of rows, taken from the start, used for training
    test_size   -> fraction of rows, taken from the end, used for testing

    Rows between the training and the test split (if any) form the validation split.
    """

    def __init__(self, path, batch_size, look_window, train_size = 0.6, test_size = 0.4):
        super().__init__()
        self.path = path
        self.batch_size = batch_size
        self.look_window = look_window
        self.train_size = train_size
        self.test_size = test_size
        # NOTE(review): unpickling executes arbitrary code; only load files you created yourself
        self.data = pd.read_pickle(path)

    # def prepare_data(self):
    #     # download

    def setup(self, stage: str):
        """Split the rows chronologically and build the datasets required by `stage`."""
        n_rows = len(self.data)
        train_end = int(self.train_size * n_rows)
        # The split points must be increasing: the test split is the trailing
        # `test_size` fraction and is clamped so it can never overlap training rows.
        test_start = max(train_end, n_rows - int(self.test_size * n_rows))

        self.train_data = self.data.iloc[:train_end]
        self.validate_data = self.data.iloc[train_end:test_start]
        self.test_data = self.data.iloc[test_start:]

        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            self.train = ElectricityDataset(self.train_data, self.look_window)
            self.validate = ElectricityDataset(self.validate_data, self.look_window)

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.test = ElectricityDataset(self.test_data, self.look_window)

        # if stage == "predict":

    def train_dataloader(self):
        # drop_last keeps every batch at exactly `batch_size` samples
        return DataLoader(self.train, batch_size=self.batch_size, drop_last = True)

    def val_dataloader(self):
        return DataLoader(self.validate, batch_size=self.batch_size, drop_last = True)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size, drop_last = True)

    # def predict_dataloader(self):
        


# Models

## PatchTST

In [11]:
# Utility functions

def create_patches(xb, patch_len, stride):
    """
    Cut every channel of a batch of series into (possibly overlapping) patches.

    xb -> [B x L x M]
    output -> [B x N x M x P], N

    The last time step is repeated `stride` times at the end of the sequence
    before the series is unfolded into windows of `patch_len` steps.
    """
    _batch, seq_len, _channels = xb.shape
    # number of patches expected after padding
    patch_num = (max(patch_len, seq_len) - patch_len) // stride + 2

    # pad the time axis with copies of the final step, then slide the window
    tail = xb[:, -1:, :].repeat_interleave(stride, dim = 1)
    xb = torch.cat([xb, tail], dim = 1).unfold(1, patch_len, stride)

    assert patch_num == xb.shape[1], f"wrong number of computed patches, expected {patch_num} but computed {xb.shape[1]}"

    return xb, patch_num

"""
ref: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
"""

def positional_encoding(batch_size, patch_num, d_model):
    """
    Fixed sinusoidal positional encoding, repeated for every element of the batch.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    batch_size -> size of the leading dimension of the returned tensor
    patch_num  -> number of positions (N)
    d_model    -> embedding size (D); odd values are supported
    output -> [B x N x D], wrapped in a Parameter that does not require gradients
    """
    # positions as a column vector: [N x 1]
    position = torch.arange(0, patch_num, dtype=torch.float32).unsqueeze(1)
    # wavelength term 10000^(2i / d_model), one per sin/cos pair: [ceil(D / 2)]
    div_term = torch.pow(10000.0, torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
    # the position is DIVIDED by the wavelength term (multiplying would alias every frequency)
    angles = position / div_term

    pe = torch.zeros(batch_size, patch_num, d_model)
    # even positions
    pe[:, :, 0::2] = torch.sin(angles)
    # odd positions (one column fewer than the even ones when d_model is odd)
    pe[:, :, 1::2] = torch.cos(angles[:, : d_model // 2])

    # if normalize:
    #     pe = pe - pe.mean()
    #     pe = pe / (pe.std() * 10)

    return nn.parameter.Parameter(pe, requires_grad= False)

In [12]:
#PatchTST

class PatchTSTEncoder(nn.Module):
    """
    Channel-independent patch encoder: normalizes the input, cuts each channel into
    patches, embeds the patches and runs them through a stack of transformer encoders.

    num_channels -> number of channels (M)
    num_var      -> length of the input sequence (L), used to compute the number of patches
    patch_len    -> patch length (P)
    stride       -> step between consecutive patches
    batch_size   -> batch size (B); the positional encoding is built for exactly B * M sequences
    d_model      -> embedding size (D)
    n_layers     -> number of stacked transformer encoder layers
    n_heads      -> attention heads of every layer
    dropout      -> dropout probability of the embedding and of every layer
    """

    def __init__(self, num_channels, num_var, patch_len, stride, batch_size, d_model = 128, n_layers = 3, n_heads = 16, dropout = 0.2):
        super(PatchTSTEncoder, self).__init__()

        self.log = setup_log(self, LOG_LEVEL)

        self.num_channels = num_channels
        # same formula used by create_patches, checked again in forward
        self.patch_num = (max(patch_len, num_var)-patch_len) // stride + 2
        self.patch_len = patch_len
        self.stride = stride
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads

        # instance normalization
        """
        ref: https://wandb.ai/wandb_fc/Normalization-Series/reports/Instance-Normalization-in-PyTorch-With-Examples---VmlldzoxNDIyNTQx
        """
        self.inst_norm = nn.InstanceNorm1d(num_channels)

        # patch creation
        self.create_patch = create_patches

        # embedding
        self.W_p = nn.Linear(patch_len, d_model, bias = False)

        # positional encoding
        self.W_pos = positional_encoding(batch_size * num_channels, self.patch_num, d_model)

        # dropout applied to the embedded patches
        self.dropout = nn.Dropout(dropout)

        # encoder: n_heads and dropout are forwarded so that the constructor arguments take effect
        self.encoders = nn.ModuleList([VanillaTransformerEncoder(d_model, n_heads, dropout) for _ in range(n_layers)])

    def forward(self, x):
        """
        x -> [B x L x M]
        output -> [(B M) x N x D]
        """
        # instance normalization works on [B x M x L], so the last two axes are swapped around it
        x = einops.rearrange(self.inst_norm(einops.rearrange(x, 'b l m -> b m l')), 'b m l -> b l m')

        # create patches
        x, patch_num = self.create_patch(x, self.patch_len, self.stride)

        assert self.patch_num == patch_num, f"wrong number for patch_num {self.patch_num} != {patch_num}"

        # reshape the tensor from [B x N x M x P] -> [(B M) x N x P]
        x = einops.rearrange(x, 'b n m p -> (b m) n p')
        # now it can be provided to our transformer implementation

        # project into transformer latent space, add the positions and regularize
        x = self.dropout(self.W_p(x) + self.W_pos)

        for layer in self.encoders:
            x = layer(x)

        return x



## Vanilla Transformer Encoder

In [13]:
# VanillaTransformer encoder

"""
https://arxiv.org/pdf/1706.03762.pdf
"""

class VanillaTransformerEncoder(nn.Module):
    """
    One post-norm transformer encoder layer: self-attention and a position-wise
    feed-forward network, each followed by dropout, a residual sum and LayerNorm.

    d_model -> embedding size of the input and of the output
    n_heads -> number of attention heads
    dropout -> dropout probability used after both sub-layers
    """

    def __init__(self, d_model, n_heads = 16, dropout = 0.2):
        super().__init__()

        self.log = setup_log(self, LOG_LEVEL)

        # attention sub-layer
        self.mha = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model) # maybe batch normalization
        self.dropout1 = nn.Dropout(dropout)

        # feed-forward sub-layer
        self.pffn = PositionWiseFeedForwardNetwork(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # self-attention: queries, keys and values are all the input
        attended = self.norm1(self.dropout1(self.mha(x, x, x)) + x)

        # the normalized attention output is both the input and the residual of the second sub-layer
        transformed = self.dropout2(self.pffn(attended))
        return self.norm2(transformed + attended)

"""
ref: https://d2l.ai/chapter_attention-mechanisms-and-transformers/multihead-attention.html
"""
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with bias-free projections.

    d_model -> embedding size of queries, keys, values and output
    n_heads -> number of heads; d_model must be divisible by it
    """

    def __init__(self, d_model, n_heads = 16):
        super(MultiHeadAttention, self).__init__()

        self.log = setup_log(self, LOG_LEVEL)

        assert d_model % n_heads == 0, "d_model must be a multiple of n_heads"

        self.n_heads = n_heads
        # size of every head
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, bias = False)
        self.W_k = nn.Linear(d_model, d_model, bias = False)
        self.W_v = nn.Linear(d_model, d_model, bias = False)
        self.W_o = nn.Linear(d_model, d_model, bias = False)

    # reshape to compute in parallel the several heads
    def reshape_vector(self, x, inverse = False):
        """
        x: [B x N x D] || [B x H x N x DIM]
        output: [B x H x N x DIM] || [B x N x D]

        With inverse = False the feature axis is split into heads; with
        inverse = True the heads are merged back using the same (dim h) layout.
        """
        out = None

        if not inverse:
            out = einops.rearrange(x, 'b n (dim h) -> b h n dim', h=self.n_heads)
        else:
            out = einops.rearrange(x, 'b h n dim -> b n (dim h)')

        return out

    """
    ref: https://machinelearningmastery.com/the-transformer-attention-mechanism/
    """

    def scaled_attention(self, q, k, v, dk):
        """
        q: [B x H x N x DIM], k: [B x H x N x DIM] , v: [B x H x N x DIM]
        dk: size of a head, used to scale the scores
        output: [B x H x N x DIM]
        """
        sqrt_d_k = math.sqrt(dk)

        # using einsum to perform batch matrix multiplication
        score = einops.einsum(q, k, 'b h n d_k, b h n_1 d_k -> b h n n_1') / sqrt_d_k

        # every query distributes its attention over the keys
        weights = F.softmax(score, dim = -1)

        res = einops.einsum(weights, v, 'b h n n_1, b h n_1 d_k -> b h n d_k')

        return res

    def forward(self, q, k, v):
        """
        q, k, v: [B x N x D]
        output: [B x N x D]
        """
        q = self.reshape_vector(self.W_q(q))
        k = self.reshape_vector(self.W_k(k))
        v = self.reshape_vector(self.W_v(v))

        # parallel computation
        out = self.scaled_attention(q, k, v, self.d_k)
        out_concat = self.reshape_vector(out, inverse = True)

        return self.W_o(out_concat)

class PositionWiseFeedForwardNetwork(nn.Module):
    """
    Two linear layers with a GELU in between, applied to the last axis of the input.

    d_model -> input and output size
    d_inner -> size of the hidden layer
    """

    def __init__(self, d_model, d_inner = 256):
        super().__init__()

        self.log = setup_log(self, LOG_LEVEL)

        # expand -> non-linearity -> project back
        self.W_1 = nn.Linear(d_model, d_inner)
        self.act = nn.GELU()
        self.W_2 = nn.Linear(d_inner, d_model)

    def forward(self, x):
        hidden = self.act(self.W_1(x))
        return self.W_2(hidden)

## CoST

In [14]:
#Cost

class Cost(nn.Module):
    """
    Disentangler that splits a representation into a trend part and a seasonal part.

    patch_num -> number of patches (N)
    d_model   -> size of the input representation (D)
    d_s, d_t  -> sizes of the seasonal and of the trend representation
    batch_size, n_layers, n_heads, dropout -> accepted but not used by this class
    """

    def __init__(self, patch_num, batch_size, d_model = 128, d_s = 64, d_t = 64, n_layers = 3, n_heads = 16, dropout = 0.2):
        super().__init__()

        self.log = setup_log(self, LOG_LEVEL)

        # dropout applied to the seasonal output only
        self.seasonal_drop = nn.Dropout(0.1)

        # trend branch
        self.tfd = TrendFeatureDisentangler(d_model, d_t, patch_num)

        # seasonal branch
        self.sfd = SeasonalFeatureDisentangler(d_model, d_s, patch_num)

    def forward(self, x):
        """
        x: [(B M) x N x D]
        outputs: {[(B M) x N x d_t], [(B M) x N x d_s]}
        """
        trend = self.tfd(x)
        # ipdb.set_trace(context = 10)
        seasonal = self.seasonal_drop(self.sfd(x))
        return trend, seasonal

     

# Causal Convolution (dilated)

class CausalConv1d(nn.Module):
    """
    1-D convolution along the patch axis whose output at step n only depends on
    the inputs at steps <= n.

    in_channels  -> size of the input features (i_C)
    out_channels -> size of the output features (o_C)
    kernel_size  -> number of taps of the convolution
    dilation     -> spacing between the taps
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super(CausalConv1d, self).__init__()

        self.log = setup_log(self, LOG_LEVEL)

        self.kernel_size = kernel_size
        self.dilation = dilation
        # Conv1d pads both ends with this many zeros; the trailing ones are removed in forward
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=self.pad, dilation=dilation)

    def forward(self, x):
        """
        input: [(B M) x N x i_C]
        output: [(B M) x N x o_C]
        """
        # we need to reshape before applying the convolution
        x = einops.rearrange(x, 'b n i_c -> b i_c n')
        x = self.conv(x)

        # drop the `pad` extra steps produced by the right padding, so the length is N again
        # (the amount depends on the dilation too, not only on the kernel size)
        if self.pad > 0:
            x = x[..., :-self.pad]

        # rearrange to the original shape
        x = einops.rearrange(x, 'b o_c n -> b n o_c')

        return x

# TFD

class TrendFeatureDisentangler(nn.Module):
    """
    Mixture of causal convolutions with kernel sizes 1, 2, 4, ... whose outputs are averaged.

    d_model   -> size of the input representation (D)
    d_t       -> size of the trend representation
    patch_num -> number of patches (N) expected by forward
    """

    def __init__(self, d_model, d_t, patch_num):
        super(TrendFeatureDisentangler, self).__init__()

        self.log = setup_log(self, LOG_LEVEL)
        self.d_model = d_model
        self.d_t = d_t
        self.patch_num = patch_num

        # https://discuss.pytorch.org/t/causal-convolution/3456/3
        # https://arxiv.org/pdf/1609.03499v2.pdf

        # floor(log2(N/2)) + 1 autoregressive experts, the i-th one with kernel size 2^i
        self.conv_num = math.floor(math.log2(patch_num / 2)) + 1
        self.convolutions = nn.ModuleList([CausalConv1d(d_model, d_t, 2**i) for i in range(self.conv_num)])

    def avg_pooling(self, input):
        """
        input: [LIST x (B M) x N x d_t]
        output: [(B M) x N x d_t], mean over the LIST axis
        """
        return einops.reduce(input, 'list b n d_t -> b n d_t', 'mean')

    def forward(self, x):
        """
        x: [(B M) x N x D]
        output: [(B M) x N x d_t]
        """
        batch_size, patch_num, d_model = x.shape

        assert patch_num == self.patch_num and d_model == self.d_model, "wrong input dimensions"

        # stack the expert outputs; unlike a preallocated buffer this keeps the
        # dtype and device produced by the convolutions
        out = torch.stack([conv(x) for conv in self.convolutions], dim = 0)

        # apply the average pooling operation
        return self.avg_pooling(out)

# SFD

class SeasonalFeatureDisentangler(nn.Module):
    """
    Seasonal branch: real FFT along the patch axis, a learnable per-frequency
    layer, then the inverse FFT back to `patch_num` steps.

    d_model   -> size of the input representation (D)
    d_s       -> size of the seasonal representation
    patch_num -> number of patches (N)
    """

    def __init__(self, d_model, d_s, patch_num):
        super().__init__()

        self.log = setup_log(self, LOG_LEVEL)

        self.patch_num = patch_num

        # rfft keeps N // 2 + 1 frequencies
        self.f = patch_num // 2 + 1

        # forward transform (time -> frequency)
        self.dft = torch.fft.rfft

        # learnable layer applied in the frequency domain
        self.fl = FourierLayer(self.f, d_model, d_s, patch_num)

        # inverse transform (frequency -> time)
        self.idft = torch.fft.irfft

    def forward(self, x):
        """
        x: [(B M) x N x D]
        output: [(B M) x N x d_s]
        """
        # transform along the temporal (patch) axis
        spectrum = self.dft(x, dim = 1)

        assert self.f == spectrum.shape[1], "wrong dimension of dft"

        # the output length is passed explicitly so odd N is reconstructed correctly
        return self.idft(self.fl(spectrum), n = self.patch_num, dim = 1)

class FourierLayer(nn.Module):
    """
    Learnable complex affine map applied independently to every frequency.

    f         -> number of frequencies (F)
    d_model   -> size of the input features (D)
    d_s       -> size of the output features
    patch_num -> accepted but not used by this class
    """

    def __init__(self, f, d_model, d_s, patch_num):
        super(FourierLayer, self).__init__()

        self.log = setup_log(self, LOG_LEVEL)

        self.f = f
        self.d_model = d_model

        self.A = nn.Parameter(torch.empty((f, d_model, d_s), dtype=torch.cfloat))
        self.B = nn.Parameter(torch.empty((f, d_s), dtype=torch.cfloat))

        # torch.empty returns uninitialized memory: give the parameters defined values
        self.reset_parameters()

    def reset_parameters(self):
        """Draw real and imaginary parts of A and B uniformly in [-1/sqrt(d_model), 1/sqrt(d_model)]."""
        bound = 1 / math.sqrt(self.d_model) if self.d_model > 0 else 0
        with torch.no_grad():
            for param in (self.A, self.B):
                real = torch.empty(param.shape).uniform_(-bound, bound)
                imag = torch.empty(param.shape).uniform_(-bound, bound)
                param.copy_(torch.complex(real, imag))

    def forward(self, x):
        """
        x: [(B M) x F x D]
        out: [(B M) x F x d_s]
        """
        _, f, _ = x.shape

        assert f == self.f, "wrong dimensions of x"

        # one [D x d_s] matrix and one bias vector per frequency
        out = einops.einsum(self.A, x, 'f d d_s, b f d -> b f d_s') + self.B

        return out




## CoPST Encoder

In [15]:
class CoPSTEncoder(nn.Module):
    """
    PatchTST backbone followed by the CoST trend/seasonal disentangler.

    patch_num    -> expected number of patches (N); checked against the backbone
    num_channels -> number of channels (M)
    num_var      -> length of the input sequence (L)
    patch_len    -> patch length (P)
    stride       -> step between consecutive patches
    batch_size   -> batch size (B)
    d_model      -> embedding size (D)
    d_s, d_t     -> sizes of the seasonal and of the trend representation
    n_layers, n_heads, dropout -> transformer hyper-parameters
    """

    def __init__(self, patch_num, num_channels, num_var, patch_len, stride, batch_size, d_model = 128, d_s = 64, d_t = 64, n_layers = 3, n_heads = 16, dropout = 0.2):
        super(CoPSTEncoder, self).__init__()

        self.log = setup_log(self, LOG_LEVEL)

        # PatchTST layer (backbone encoder); hyper-parameters are forwarded instead of being ignored
        self.ptst = PatchTSTEncoder(num_channels, num_var, patch_len, stride, batch_size,
                                    d_model = d_model, n_layers = n_layers, n_heads = n_heads, dropout = dropout)
        assert patch_num == self.ptst.patch_num, f"wrong number of patches ({patch_num} != {self.ptst.patch_num})"
        self.patch_num = self.ptst.patch_num

        # CoST layer (disentangler)
        self.cost = Cost(self.patch_num, batch_size, d_model = d_model, d_s = d_s, d_t = d_t,
                         n_layers = n_layers, n_heads = n_heads, dropout = dropout)

    def forward(self, x):
        """
        x: [B x L x M]
        outputs: {[(B M) x N x d_t], [(B M) x N x d_s]}
        """
        x = self.ptst(x)

        x = self.cost(x)

        return x

## CoPST Model

In [16]:
# Contrastive model similar to MoCo
"""
https://arxiv.org/pdf/1911.05722.pdf
https://github.com/facebookresearch/moco/blob/main/moco/builder.py
"""

class CoPSTModel(pl.LightningModule):
    """
    MoCo-style contrastive model built on two CoPST encoders.

    The query encoder and head are trained by back-propagation; the key encoder and
    head are momentum-averaged copies. The trend representation is trained against a
    queue of past keys, the seasonal representation with an in-batch loss computed on
    amplitude and phase of its spectrum.

    patch_num, num_channels, look_window, patch_len, stride, batch_size -> forwarded to CoPSTEncoder
    comp_dimension -> input/output size of the projection heads and height of the queue
    alpha          -> weight of the seasonal loss
    K              -> number of keys kept in the queue
    m              -> momentum of the key-encoder update
    T              -> softmax temperature
    lr, om, wd     -> SGD learning rate, momentum and weight decay
    """

    def __init__(self, patch_num, num_channels, look_window, patch_len, stride, batch_size, 
                 comp_dimension = 64, alpha = 0.05, K = 65536, m = 0.999, T = 0.07,
                 lr = 1e-3, om = 0.9, wd = 1e-4):
        super(CoPSTModel, self).__init__()
        self.save_hyperparameters(ignore=['encoder_q', 'encoder_k'])

        self.K = K
        self.m = m
        self.T = T

        self.alpha = alpha

        self.lr = lr
        self.om = om
        self.wd = wd

        self.encoder_q = CoPSTEncoder(patch_num, num_channels, look_window, patch_len, stride, batch_size)
        self.encoder_k = copy.deepcopy(self.encoder_q)

        self.patch_num = self.encoder_q.patch_num

        # projections head for queries and keys
        self.head_q = nn.Sequential(
            nn.Linear(comp_dimension, comp_dimension),
            nn.ReLU(),
            nn.Linear(comp_dimension, comp_dimension)
        )
        self.head_k = nn.Sequential(
            nn.Linear(comp_dimension, comp_dimension),
            nn.ReLU(),
            nn.Linear(comp_dimension, comp_dimension)
        )

        # initialize the parameters of the key encoder and projection head
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False # the key encoder will be updated by the momentum update
        
        for param_q, param_k in zip(self.head_q.parameters(), self.head_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False # the head_k will be updated by the momentum update

        # register a dictionary buffer as a queue (decoupled from the minibatch size)
        # https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer
        self.register_buffer('queue', F.normalize(torch.randn(comp_dimension, K), dim=0))
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

    def configure_optimizers(self):
        """Return SGD together with its cosine-annealing schedule (stepped once per epoch)."""
        # set parameters of SGD
        optimizer = torch.optim.SGD(self.parameters(), lr = self.lr, momentum = self.om, weight_decay = self.wd)
        # the scheduler must be handed to Lightning, otherwise it is never stepped
        cosine_annealing = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = optimizer, T_max = 100)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": cosine_annealing, "interval": "epoch"},
        }

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update for key encoder and key head: k = m * k + (1 - m) * q
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1 - self.m)

        for param_q, param_k in zip(self.head_q.parameters(), self.head_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1 - self.m)

    def compute_loss(self, q, k, k_negs):
        """
        InfoNCE loss of the queries against one positive key each and the queued negatives.

        q, k: [B x C], k_negs: [C x K]
        """
        # compute logits
        # positive logits: Bx1 (one timestamp as positive)
        l_pos = einops.einsum(q, k, 'b c,b c->b').unsqueeze(-1)
        # negative logits: BxK
        l_neg = einops.einsum(q, k_negs, 'b c,c k->b k')

        # logits: Bx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators - the positive logit is always in column 0,
        # so we can consider this as a classification problem and use the CE
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device = logits.device)
        loss = F.cross_entropy(logits, labels)

        return loss
    
    def get_polar(self, x):
        """Return (amplitude, phase) of a complex tensor."""
        return (x.abs(), x.angle())

    def instance_contrastive_loss(self, z1, z2):
        """
        In-batch contrastive loss computed independently for every frequency.

        z1, z2: [B x F x d_s]; row i of z1 and row i of z2 form the positive pair.
        """
        B = z1.shape[0]
        z = torch.cat([z1, z2], dim=0)  # 2B x F x d_s
        z = einops.rearrange(z, 'b f d_s -> f b d_s')  # F x 2B x d_s
        sim = einops.einsum(z, z, 'f b_1 d_s, f b_2 d_s -> f b_1 b_2')  # F x 2B x 2B
        # remove the diagonal (self-similarity): lower part shifted left + upper part
        logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # F x 2B x (2B-1)
        logits += torch.triu(sim, diagonal=1)[:, :, 1:]
        logits = -F.log_softmax(logits, dim=-1)
        # log.debug(logits)

        # the index tensor lives on the same device as the data
        i = torch.arange(B, device = z1.device)
        loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
        return loss

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Overwrite the oldest queue columns with `keys` ([B x C]) and advance the pointer."""
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0, "K must be a multiple of batch_size"

        # replace keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T

        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr
    
    def training_step(self, batch, batch_idx):
        x_q, x_k = batch
        loss = self.forward(x_q, x_k)

        # logs metrics for each training_step,
        # and the average across the epoch, to the progress bar and logger
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    # def validation_step(self, batch, batch_idx):
    #     x_q, x_k = batch
    #     loss = self.forward(x_q, x_k)

    #     # logs metrics for each training_step,
    #     # and the average across the epoch, to the progress bar and logger
    #     self.log("val_loss", loss)
    #     return loss

    # def test_step(self, batch, batch_idx):
    #     x_q, x_k = batch
    #     loss = self.forward(x_q, x_k)

    #     # logs metrics for each training_step,
    #     # and the average across the epoch, to the progress bar and logger
    #     self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    #     return loss

    def forward(self, x_q, x_k):
        """
        x_q, x_k: [B x L x M], two augmented views of the same windows
        output: scalar loss = trend loss + alpha * seasonal loss / 2
        """
        # select a random timestamp
        rand_idx = np.random.randint(0, self.patch_num)

        # trend and seasonal queries
        q_t, q_s = self.encoder_q(x_q)

        q_t = F.normalize(self.head_q(q_t[:, rand_idx]), dim=-1)

        # compute key features
        with torch.no_grad():  # no gradient update for keys (momentum update will be used)
            self._momentum_update_key_encoder()  # update key encoder using momentum
            k_t, k_s = self.encoder_k(x_k)
            k_t = F.normalize(self.head_k(k_t[:, rand_idx]), dim=-1)

        loss = 0

        # the queue is compared before the new keys are inserted
        loss += self.compute_loss(q_t, k_t, self.queue.clone().detach())
        self._dequeue_and_enqueue(k_t)

        # the seasonal keys come from the query encoder, so gradients flow through both views
        q_s = F.normalize(q_s, dim=-1)
        _, k_s = self.encoder_q(x_k)
        k_s = F.normalize(k_s, dim=-1)

        # the frequency and phase loss must be computed in the frequency domain
        q_s_freq = torch.fft.rfft(q_s, dim=1)
        k_s_freq = torch.fft.rfft(k_s, dim=1)
        q_s_amp, q_s_phase = self.get_polar(q_s_freq)
        k_s_amp, k_s_phase = self.get_polar(k_s_freq)

        seasonal_loss = self.instance_contrastive_loss(q_s_amp, k_s_amp) + \
                        self.instance_contrastive_loss(q_s_phase,k_s_phase)
        loss += (self.alpha * (seasonal_loss/2))

        return loss

# Train

## Helper functions

In [48]:
# random seed functions
def seed_everything(seed):
    """
    Seed every random number generator used by the notebook with `seed` and put
    cuDNN in deterministic mode.
    """
    # Lightning first, then each library explicitly
    pl.seed_everything(seed, workers=True)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # deterministic kernels, no auto-tuner
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [49]:
# train function
def train(batch_size, datamodule, model, model_name, max_epochs = 500, checkpoint_every_n_epochs = 5, 
          check_val_every_n_epoch = 5, resume_training = True, load_model = False, 
          enable_checkpoint = True, monitor_metric = "val_loss", checkpoint_dir = None, logs_dir = None, 
          early_stopping = True, deterministic = False):
    """
    Build a Lightning Trainer and fit `model` on `datamodule`.

    batch_size, load_model    -> accepted but not used by this function
    model_name                -> prefix of the checkpoint files and name of the TensorBoard run
    max_epochs                -> maximum number of training epochs
    checkpoint_every_n_epochs -> checkpoint period
    check_val_every_n_epoch   -> validation period, forwarded to the Trainer
    resume_training           -> resume from `<checkpoint_dir>/last.ckpt` when that file exists
    enable_checkpoint         -> enable checkpointing (a callback is added only if checkpoint_dir is given)
    monitor_metric            -> metric watched by the checkpoint callback: "train_loss" or "val_loss"
    checkpoint_dir, logs_dir  -> output folders; None disables the corresponding feature
    early_stopping            -> add an EarlyStopping callback
    deterministic             -> forwarded to the Trainer
    """
    import os  # local import: only needed to build the resume path

    # check monitor metric
    assert monitor_metric in ["train_loss", "val_loss"], "metric to monitor is invalid"
    
    # initialize callbacks array
    callbacks = [TQDMProgressBar(refresh_rate=20)]

    # add checkpoints to callbacks
    checkpoint_callback = None
    if enable_checkpoint and checkpoint_dir is not None:
        checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_dir,  monitor = monitor_metric, filename=model_name + '-mnist-{epoch:02d}-{' + monitor_metric + ':.2f}',
                                         save_last =True, every_n_epochs = checkpoint_every_n_epochs, save_on_train_epoch_end = True)
        callbacks.append(checkpoint_callback)
        
    # add early stopping to the callbacks
    # NOTE(review): early stopping always watches "val_loss", independently of monitor_metric
    if early_stopping:
        callbacks.append(EarlyStopping(monitor="val_loss", min_delta = 0.1, patience = 3, mode="min", check_on_train_epoch_end = False))
                              
    # define the logger object
    logger = None
    if logs_dir is not None:
        logger = TensorBoardLogger(logs_dir, name=model_name)

    # create the Trainer
    trainer = pl.Trainer(enable_checkpointing=enable_checkpoint, devices=1, accelerator="auto", 
                         max_epochs=max_epochs, logger=logger, callbacks=callbacks, 
                         check_val_every_n_epoch = check_val_every_n_epoch, detect_anomaly=True,
                         deterministic = deterministic)

    # resume from the checkpoint written by `save_last = True`, if there is one
    ckpt_path = None
    if resume_training and checkpoint_dir is not None:
        last_ckpt = os.path.join(checkpoint_dir, "last.ckpt")
        if os.path.isfile(last_ckpt):
            ckpt_path = last_ckpt
        else:
            log.warning(f"no checkpoint found at {last_ckpt}, training starts from scratch")
    trainer.fit(ckpt_path = ckpt_path, model=model, datamodule=datamodule)
    if checkpoint_callback is not None:
        log.info(checkpoint_callback.best_model_path)

In [50]:
# initialize dataset
datamodule = ElectricityDataModule(datasets_path[ELECTRICITY] + datasets_processed_name[ELECTRICITY], BATCH_SIZE, L)

# set seed if deterministic
if DETERMINISTIC:
    seed_everything(1)

# load an existing model, or initialize a new one
if LOAD_MODEL:
    # load_from_checkpoint needs a checkpoint file: use the one written by `save_last = True`
    model = CoPSTModel.load_from_checkpoint(CHECKPOINT_FOLDER + "/last.ckpt")
else:
    # the queue size K is a multiple of the number of keys produced per step (BATCH_SIZE * M)
    model = CoPSTModel(PATCH_NUM, M, L, PATCH_LEN, STRIDE, BATCH_SIZE, K = BATCH_SIZE * M * DICT_MOMENTUM_SIZE)

INFO:lightning_fabric.utilities.seed:Global seed set to 1


In [51]:
# 100-epoch run from scratch: checkpoints monitor "train_loss", resuming and early stopping are turned off.
train(BATCH_SIZE, datamodule, model, CoPST, max_epochs = 100, check_val_every_n_epoch = None, 
      resume_training = False, monitor_metric = "train_loss", checkpoint_dir = CHECKPOINT_FOLDER, 
      logs_dir = LOGS_FOLDER, early_stopping = False, deterministic = DETERMINISTIC)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type         | Params
-------------------------------------------
0 | encoder_q | CoPSTEncoder | 2.0 M 
1 | encoder_k | CoPSTEncoder | 2.0 M 
2 | head_q    | Sequential   | 8.3 K 
3 | head_k    | Sequential   | 8.3 K 
-------------------------------------------
504 K     Trainable params
3.5 M     Non-trainable params
4.0 M     Total params
16.157

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
INFO:APP:
