<a href="https://colab.research.google.com/github/whoami-Lory271/thesis-project/blob/main/thesis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Symbol legend

* B: batch size 
* M: number of channels
* P: patch dimension
* N: number of patches
* L: lookback window


# Installations and imports


In [1]:
!pip install pytorch-lightning==2.0.1.post0 --quiet
!pip install einops==0.6.1 --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m718.6/718.6 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m28.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m40.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m21.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m149.6/149.6 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.5/114.5 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import pandas as pd
import logging
import copy
from google.colab import drive
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cmath
# https://theaisummer.com/einsum-attention/
import einops
import math
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl

In [3]:
# Mount Google Drive so that the paths built from ROOT_FOLDER (under /content/drive) resolve.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Constants

In [14]:
# logger
# Level applied to the 'APP' logger and to the per-class loggers created by setup_log.
LOG_LEVEL = logging.DEBUG

# paths
# Key used to index the dataset lookup dictionaries.
ELECTRICITY = "electricity"
# Project root on the mounted Google Drive.
ROOT_FOLDER = "/content/drive/MyDrive/Università/Magistrale/Tesi/code"

# hyperparameters
# NOTE(review): BATCH_SIZE is not referenced by the visible cells — confirm where it is used.
BATCH_SIZE = 16

# Logger

In [15]:
# Create the notebook-wide logger and apply the configured level.
log = logging.getLogger('APP')
log.setLevel(LOG_LEVEL)

# Optional console handler, currently disabled: no handler is attached in this cell.
# # create console handler and set level to debug
# ch = logging.StreamHandler()
# ch.setLevel(logging.INFO)

# # create formatter
# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# # add formatter to ch
# ch.setFormatter(formatter)

# # add ch to logger
# log.addHandler(ch)

In [16]:
# Per-class logger factory
def setup_log(self, level):
    """
    Return the logger named after the class of `self`, configured with `level`.

    self: any object; only its class name is used
    level: logging level to set on the returned logger
    """
    class_logger = logging.getLogger(self.__class__.__name__)
    class_logger.setLevel(level)
    return class_logger

In [17]:
# 'application' code
# Smoke test: both messages are emitted because the 'APP' logger is set to LOG_LEVEL (DEBUG).
log.debug('debug message')
log.info('info message')
# log.warning('warn message')
# log.error('error message')
# log.critical('critical message')

DEBUG:APP:debug message
INFO:APP:info message


# Preprocessing

## Datasets

In [None]:
# Folder holding each dataset, keyed by dataset identifier.
datasets_path = {ELECTRICITY: f"{ROOT_FOLDER}/datasets/electricity"}

# File name (with leading slash) of the raw text file of each dataset.
datasets_name = {ELECTRICITY: "/LD2011_2014.txt"}

# File name (with leading slash) of the pickled dataframe of each dataset.
datasets_processed_name = {ELECTRICITY: "/electricity.pkl"}

### Electricity

**Preprocessing**

In [None]:
# df = pd.read_csv(datasets_path[ELECTRICITY] + datasets_name[ELECTRICITY], sep = ';')
# df.rename(columns={df.columns[0]: 'Date'},inplace=True)
# df.to_pickle(datasets_path[ELECTRICITY] + datasets_processed_name[ELECTRICITY])

In [None]:
# Load the pickled Electricity dataframe (path = dataset folder + processed file name).
df = pd.read_pickle(datasets_path[ELECTRICITY] + datasets_processed_name[ELECTRICITY])

In [None]:
class ElectricityDataset(Dataset):
    """
    Row-wise view over a dataframe: item `idx` is row `idx` without its first column.

    NOTE(review): items are returned as pandas objects, not tensors — confirm the
    collate function used by the DataLoader can batch them.
    """

    def __init__(self, data):
        super().__init__()
        # dataframe indexed positionally through .iloc
        self.data = data

    def __len__(self):
        # number of rows
        return len(self.data)

    def __getitem__(self, idx):
        # skip column 0 and keep every remaining column of the selected row
        return self.data.iloc[idx, 1:]

class ElectricityDataModule(pl.LightningDataModule):
    """
    Lightning data module that splits the pickled Electricity dataframe chronologically.

    path: location of the pickled dataframe
    batch_size: batch size used by every dataloader
    train_size: fraction of rows (from the start) used for training
    test_size: fraction of rows (from the end) used for testing

    Rows between the training and the test block form the validation split; with the
    default fractions (0.6 + 0.4) that block is empty.

    Raises ValueError when the fractions are outside [0, 1] or sum to more than 1.
    """
    def __init__(self, path, batch_size, train_size = 0.6, test_size = 0.4):
        super().__init__()

        fractions_ok = 0 <= train_size <= 1 and 0 <= test_size <= 1
        if not fractions_ok or train_size + test_size > 1 + 1e-9:
            raise ValueError(
                f"train_size and test_size must lie in [0, 1] and sum to at most 1, "
                f"got train_size={train_size}, test_size={test_size}"
            )

        self.path = path
        # stored because every dataloader reads it
        self.batch_size = batch_size

        # NOTE(review): unpickling executes arbitrary code — only load trusted files.
        data = pd.read_pickle(path)

        # contiguous, non-overlapping blocks: [train | validation | test]
        n_rows = len(data)
        train_end = int(train_size * n_rows)
        test_start = max(train_end, n_rows - int(test_size * n_rows))
        self.train_data = data.iloc[:train_end]
        self.validate_data = data.iloc[train_end:test_start]
        self.test_data = data.iloc[test_start:]

    # prepare_data is not needed: nothing is downloaded

    def setup(self, stage: str):
        """Build the datasets required by the given Lightning stage."""
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            self.train = ElectricityDataset(self.train_data)

        # the validation dataset is needed both while fitting and for a standalone validation run
        if stage in ("fit", "validate"):
            self.validate = ElectricityDataset(self.validate_data)

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.test = ElectricityDataset(self.test_data)

        # the "predict" stage is not supported yet

    def train_dataloader(self):
        """Dataloader over the training split (incomplete last batch dropped)."""
        return DataLoader(self.train, batch_size=self.batch_size, drop_last = True)

    def val_dataloader(self):
        """Dataloader over the validation split (incomplete last batch dropped)."""
        return DataLoader(self.validate, batch_size=self.batch_size, drop_last = True)

    def test_dataloader(self):
        """Dataloader over the test split (incomplete last batch dropped)."""
        return DataLoader(self.test, batch_size=self.batch_size, drop_last = True)

    # def predict_dataloader(self):
        


# Models

## PatchTST

In [18]:
# Scratch check of the patching recipe on a toy [4 x 20 x 2] tensor.
x = torch.randint(20, size = (4,20,2))
print(x.shape)
print("---------------------------------")
# take the last time step and repeat it twice (as many times as the unfold step below)
tail = x[:,-1:,:]
tail = torch.repeat_interleave(tail, 2, dim = 1)
# append the repeated step at the end of the sequence
x = torch.concatenate((x,tail), axis = 1)
print(x.shape)
# sliding windows of length 5 every 2 steps; the window axis is appended last
x = x.unfold(dimension=1, size=5, step=2)
print(x.shape)

torch.Size([4, 20, 2])
---------------------------------
torch.Size([4, 22, 2])
torch.Size([4, 9, 2, 5])


In [19]:
# Compare two formulations of the positional-encoding frequency term for d_model = 16.
log.debug("my term")
# 10000^(2i/16): grows from 1 to about 3162
div_term = torch.pow(10000.0, torch.arange(0, 16, 2) / 16) 
print(div_term)
log.debug("other term")
# exp(-2i * ln(10000) / 16) = 10000^(-2i/16): the reciprocal of the term above
div_term = torch.exp(torch.arange(0, 16, 2) * -(math.log(10000.0) / 16))
print(div_term)

DEBUG:APP:my term
DEBUG:APP:other term


tensor([1.0000e+00, 3.1623e+00, 1.0000e+01, 3.1623e+01, 1.0000e+02, 3.1623e+02,
        1.0000e+03, 3.1623e+03])
tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])


In [20]:
# Shape check for batched attention-style products on [32 x 16 x 7 x 8] tensors.
x = torch.randint(20, size = (32,16,7,8), dtype=torch.float32)
y = torch.randint(20, size = (32,16,7,8), dtype=torch.float32)
# [.. x 7 x 8] @ [.. x 8 x 7] -> [.. x 7 x 7]
score = x @ y.transpose(2,3)
# [.. x 7 x 7] @ [.. x 7 x 8] -> back to [32 x 16 x 7 x 8]
log.debug((score @ x).shape)

DEBUG:APP:torch.Size([32, 16, 7, 8])


In [21]:
# Utility functions

def create_patches(xb, patch_len, stride):
    """
    Cut a batch of series into (possibly overlapping) patches along the time axis.

    xb -> [B x L x M]
    output -> [B x N x M x P], N

    The last time step is repeated `stride` times at the end of the sequence before
    unfolding, which yields one patch more than plain unfolding would.
    """
    _batch, seq_len, _channels = xb.shape

    # number of patches expected after padding
    expected_patches = (max(patch_len, seq_len) - patch_len) // stride + 2

    # pad the sequence by repeating its last time step
    last_step = xb[:, -1:, :]
    padding = torch.repeat_interleave(last_step, stride, dim = 1)
    padded = torch.concatenate((xb, padding), dim = 1)

    # sliding windows over the time axis; the window axis becomes the last one
    patches = padded.unfold(dimension=1, size=patch_len, step=stride)

    assert expected_patches == patches.shape[1], f"wrong number of computed patches, expected {expected_patches} but computed {patches.shape[1]}"

    return patches, expected_patches

"""
ref: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
"""

def positional_encoding(batch_size, patch_num, d_model):
    """
    Fixed sinusoidal positional encoding, replicated for every element of the batch.

    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

    batch_size: number of copies along the first axis
    patch_num: number of positions N
    d_model: encoding dimension D (odd values are supported)
    output -> [B x N x D], a non-trainable nn.Parameter
    """
    pe = torch.zeros(batch_size, patch_num, d_model)
    # column vector of positions, [N x 1]
    position = torch.arange(0, patch_num, dtype=torch.float32).unsqueeze(1)
    # wavelength term 10000^(2i / d_model), one entry per sin/cos pair
    div_term = torch.pow(10000.0, torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
    # the position is DIVIDED by the term, so that higher dimensions oscillate more slowly
    angles = position / div_term
    # even positions
    pe[:, :, 0::2] = torch.sin(angles)
    # odd positions (one column fewer than the even ones when d_model is odd)
    pe[:, :, 1::2] = torch.cos(angles[:, : d_model // 2])

    return nn.parameter.Parameter(pe, requires_grad= False)

In [22]:
#PatchTST

class PatchTSTEncoder(nn.Module):
    """
    Channel-independent patch encoder: instance-normalises the series, cuts it into
    patches, embeds every patch and runs a stack of Transformer encoder layers.

    num_channels: number of channels M
    num_var: length of the lookback window L (used to derive the number of patches)
    patch_len: patch dimension P
    stride: step between the starts of two consecutive patches
    batch_size: batch size B used to size the positional-encoding tensor
    d_model: latent dimension D
    n_layers: number of stacked encoder layers
    n_heads: attention heads of every encoder layer
    dropout: dropout probability (patch embedding and encoder layers)
    """
    def __init__(self, num_channels, num_var, patch_len, stride, batch_size = 16, d_model = 128, n_layers = 3, n_heads = 16, dropout = 0.2):
        super(PatchTSTEncoder, self).__init__()

        self.log = setup_log(self, LOG_LEVEL)

        self.num_channels = num_channels
        # same formula used by create_patches (one extra patch comes from its padding)
        self.patch_num = (max(patch_len, num_var)-patch_len) // stride + 2
        self.patch_len = patch_len
        self.stride = stride
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads

        # instance normalization
        """
        ref: https://wandb.ai/wandb_fc/Normalization-Series/reports/Instance-Normalization-in-PyTorch-With-Examples---VmlldzoxNDIyNTQx
        """
        self.inst_norm = nn.InstanceNorm1d(num_channels)

        # patch creation
        self.create_patch = create_patches

        # embedding
        self.W_p = nn.Linear(patch_len, d_model, bias = False)

        # positional encoding, [(B M) x N x D] with identical rows along the first axis
        self.W_pos = positional_encoding(batch_size * num_channels, self.patch_num, d_model)

        # dropout applied to the embedded patches
        self.dropout = nn.Dropout(dropout)

        # encoder: the attention heads and the dropout rate are forwarded to every layer
        self.encoders = nn.ModuleList([VanillaTransformerEncoder(d_model, n_heads, dropout) for _ in range(n_layers)])

    def forward(self, x):
        """
        x -> [B x L x M]
        output -> [(B M) x N x D]
        """
        # instance normalization expects [B x M x L], so we reshape before and after applying it
        x = einops.rearrange(self.inst_norm(einops.rearrange(x, 'b l m -> b m l')), 'b m l -> b l m')

        # create patches, [B x N x M x P]
        x, _ = self.create_patch(x, self.patch_len, self.stride)

        # reshape the tensor from [B x N x M x P] -> [(B M) x N x P]
        x = einops.rearrange(x, 'b n m p -> (b m) n p')
        # now it can be provided to our transformer implementation

        # project into transformer latent space; the encoding is the same for every batch
        # element, so a single [N x D] slice is broadcast and any batch size is accepted
        x = self.dropout(self.W_p(x) + self.W_pos[0])

        for layer in self.encoders:
            x = layer(x)

        return x



## Vanilla Transformer Encoder

In [23]:
# VanillaTransformer encoder

"""
https://arxiv.org/pdf/1706.03762.pdf
"""

class VanillaTransformerEncoder(nn.Module):
    """
    Single post-norm Transformer encoder layer: self-attention and a position-wise
    feed-forward network, each followed by dropout, a residual sum and LayerNorm.

    d_model: latent dimension D
    n_heads: number of attention heads
    dropout: dropout probability of both sub-layers
    """
    def __init__(self, d_model, n_heads = 16, dropout = 0.2):
        super().__init__()

        self.log = setup_log(self, LOG_LEVEL)

        # attention sub-layer
        self.mha = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model) # maybe batch normalization
        self.dropout1 = nn.Dropout(dropout)

        # feed-forward sub-layer
        self.pffn = PositionWiseFeedForwardNetwork(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        """
        x: [B x N x D]
        output: [B x N x D]
        """
        # self-attention with residual connection around the input
        attended = self.norm1(self.dropout1(self.mha(x, x, x)) + x)

        # feed-forward with residual connection around the attention output
        transformed = self.dropout2(self.pffn(attended))
        return self.norm2(transformed + attended)

"""
ref: https://d2l.ai/chapter_attention-mechanisms-and-transformers/multihead-attention.html
"""
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with bias-free linear projections.

    d_model: latent dimension D, must be divisible by n_heads
    n_heads: number of heads H; every head works on DIM = D / H features
    """
    def __init__(self, d_model, n_heads = 16):
        super(MultiHeadAttention, self).__init__()

        self.log = setup_log(self, LOG_LEVEL)

        assert d_model % n_heads == 0, "d_model must be a multiple of n_heads"

        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, bias = False)
        self.W_k = nn.Linear(d_model, d_model, bias = False)
        self.W_v = nn.Linear(d_model, d_model, bias = False)
        self.W_o = nn.Linear(d_model, d_model, bias = False)

    # reshape to compute in parallel the several heads
    def reshape_vector(self, x, inverse = False):
        """
        Split the feature axis into heads (inverse = False) or merge the heads back
        (inverse = True). The two directions use the same feature layout, so applying
        one after the other restores the input.

        x: [B x N x D] || [B x H x N x DIM]
        output: [B x H x N x DIM] || [B x N x D]
        """
        out = None

        if not inverse:
            out = einops.rearrange(x, 'b n (dim h) -> b h n dim', h=self.n_heads)
        else:
            out = einops.rearrange(x, 'b h n dim -> b n (dim h)')

        return out

    """
    ref: https://machinelearningmastery.com/the-transformer-attention-mechanism/
    """

    def scaled_attention(self, q, k, v, dk):
        """
        softmax(q k^T / sqrt(dk)) v, computed independently for every head.

        q: [B x H x N x DIM], k: [B x H x N x DIM] , v: [B x H x N x DIM]
        dk: per-head feature dimension used for the scaling
        output: [B x H x N x DIM]
        """
        sqrt_d_k = math.sqrt(dk)

        # using einsum to perform batch matrix multiplication
        score = einops.einsum(q, k, 'b h n d_k, b h n_1 d_k -> b h n n_1') / sqrt_d_k

        # normalise over the key axis
        weights = F.softmax(score, dim = -1)

        res = einops.einsum(weights, v, 'b h n n_1, b h n_1 d_k -> b h n d_k')

        return res

    def forward(self, q, k, v):
        """
        q, k, v: [B x N x D]
        output: [B x N x D]
        """
        q = self.reshape_vector(self.W_q(q))
        k = self.reshape_vector(self.W_k(k))
        v = self.reshape_vector(self.W_v(v))

        # parallel computation
        out = self.scaled_attention(q, k, v, self.d_k)
        # concatenate the heads before the output projection
        out_concat = self.reshape_vector(out, inverse = True)

        return self.W_o(out_concat)

class PositionWiseFeedForwardNetwork(nn.Module):
    """
    Two-layer feed-forward network with a GELU in between, applied on the last axis.

    d_model: input and output dimension
    d_inner: hidden dimension
    """
    def __init__(self, d_model, d_inner = 256):
        super().__init__()

        self.log = setup_log(self, LOG_LEVEL)

        # expansion, non-linearity, projection back to d_model
        self.W_1 = nn.Linear(d_model, d_inner)
        self.act = nn.GELU()
        self.W_2 = nn.Linear(d_inner, d_model)

    def forward(self, x):
        """Map [... x d_model] to [... x d_model] through the hidden layer."""
        hidden = self.act(self.W_1(x))
        return self.W_2(hidden)

## Cost

In [24]:
#Cost

class Cost(nn.Module):
    """
    Disentangler that maps the encoder output to a trend and a seasonal representation.

    patch_num: number of patches N
    d_model: dimension D of the incoming representation
    d_s: dimension of the seasonal representation
    d_t: dimension of the trend representation
    batch_size, n_layers, n_heads, dropout: accepted but not used by this module
    """
    def __init__(self, patch_num, batch_size = 16, d_model = 128, d_s = 64, d_t = 64, n_layers = 3, n_heads = 16, dropout = 0.2):
        super().__init__()

        self.log = setup_log(self, LOG_LEVEL)

        # dropout applied to the seasonal branch only (fixed rate)
        self.seasonal_drop = nn.Dropout(0.1)

        # trend branch
        self.tfd = TrendFeatureDisentangler(d_model, d_t, patch_num)

        # seasonal branch
        self.sfd = SeasonalFeatureDisentangler(d_model, d_s, patch_num)

    def forward(self, x):
        """
        x: [(B M) x N x D]
        outputs: {[(B M) x N x d_t], [(B M) x N x d_s]}
        """
        trend = self.tfd(x)
        seasonal = self.seasonal_drop(self.sfd(x))
        return trend, seasonal

     

# Causal Convolution (dilated)

class CausalConv1d(nn.Module):
    """
    1D convolution whose output at step t only depends on inputs at steps <= t.

    in_channels: number of input features i_C
    out_channels: number of output features o_C
    kernel_size: width of the convolution kernel
    dilation: spacing between kernel taps
    """
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super(CausalConv1d, self).__init__()

        self.log = setup_log(self, LOG_LEVEL)

        self.kernel_size = kernel_size
        # amount of zero padding added by Conv1d on BOTH sides of the sequence
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=self.pad, dilation=dilation)

    def forward(self, x):
        """
        input: [(B M) x N x i_C]
        output: [(B M) x N x o_C]
        """
        # we need to reshape before applying the convolution
        x = einops.rearrange(x, 'b n i_c -> b i_c n')
        x = self.conv(x)

        # the symmetric padding makes the output `pad` steps longer than the input: the
        # extra trailing steps look at the right padding and are dropped to keep causality
        # (nothing to drop when the kernel has size 1)
        if self.pad > 0:
            x = x[..., :-self.pad]

        # rearrange to the original shape
        x = einops.rearrange(x, 'b o_c n -> b n o_c')

        return x

# TFD

class TrendFeatureDisentangler(nn.Module):
    """
    Mixture of causal convolutions with kernel sizes 1, 2, 4, ... whose outputs are averaged.

    d_model: dimension D of the incoming representation
    d_t: dimension of the trend representation
    patch_num: number of patches N (at least 2)

    Raises ValueError when patch_num is smaller than 2 (no expert could be built).
    """
    def __init__(self, d_model, d_t, patch_num):
        super(TrendFeatureDisentangler, self).__init__()

        self.log = setup_log(self, LOG_LEVEL)

        if patch_num < 2:
            raise ValueError(f"patch_num must be at least 2, got {patch_num}")

        self.d_model = d_model
        self.d_t = d_t
        self.patch_num = patch_num

        # https://discuss.pytorch.org/t/causal-convolution/3456/3
        # https://arxiv.org/pdf/1609.03499v2.pdf

        # floor(log2(N/2)) + 1 autoregressive experts, the i-th one with kernel size 2^i
        self.conv_num = math.floor(math.log2(patch_num / 2)) + 1
        self.convolutions = nn.ModuleList([CausalConv1d(d_model, d_t, 2**i) for i in range(self.conv_num)])

    def avg_pooling(self, input):
        """
        Average over the leading (expert) axis.

        input: [LIST x (B M) x N x d_t]
        output: [(B M) x N x d_t]
        """
        return einops.reduce(input, 'list b n d_t -> b n d_t', 'mean')

    def forward(self, x):
        """
        x: [(B M) x N x D]
        output: [(B M) x N x d_t]
        """
        batch_size, patch_num, d_model = x.shape

        assert patch_num == self.patch_num and d_model == self.d_model, "wrong input dimensions"

        # create the result tensor on the same device and with the same dtype as the input
        out = torch.zeros((self.conv_num, batch_size, patch_num, self.d_t), device=x.device, dtype=x.dtype)

        for i, conv in enumerate(self.convolutions):
            out[i,...] = conv(x)

        # apply the average pooling operation
        out = self.avg_pooling(out)

        return out

# SFD

class SeasonalFeatureDisentangler(nn.Module):
    """
    Seasonal branch: real FFT along the patch axis, a learnable per-frequency layer,
    then the inverse real FFT back to the time domain.

    d_model: dimension D of the incoming representation
    d_s: dimension of the seasonal representation
    patch_num: number of patches N
    """
    def __init__(self, d_model, d_s, patch_num):
        super().__init__()

        self.log = setup_log(self, LOG_LEVEL)

        self.patch_num = patch_num

        # rfft keeps only the non-negative frequencies: N // 2 + 1 of them
        self.f = patch_num // 2 + 1

        # forward transform
        self.dft = torch.fft.rfft

        # learnable layer acting in the frequency domain
        self.fl = FourierLayer(self.f, d_model, d_s, patch_num)

        # inverse transform
        self.idft = torch.fft.irfft

    def forward(self, x):
        """
        x: [(B M) x N x D]
        output: [(B M) x N x d_s]
        """
        # to the frequency domain, along the temporal axis
        spectrum = self.dft(x, dim = 1)

        assert self.f == spectrum.shape[1], "wrong dimension of dft"

        # per-frequency affine map
        mapped = self.fl(spectrum)

        # back to the time domain; the explicit length avoids the odd-length ambiguity
        return self.idft(mapped, n = self.patch_num, dim = 1)

class FourierLayer(nn.Module):
    """
    Learnable complex affine map applied independently to every frequency.

    f: number of frequencies F
    d_model: input dimension D
    d_s: output dimension
    patch_num: accepted for interface compatibility, not used
    """
    def __init__(self, f, d_model, d_s, patch_num):
        super(FourierLayer, self).__init__()

        self.log = setup_log(self, LOG_LEVEL)

        self.f = f
        self.d_model = d_model

        # the parameters must be initialised explicitly: an empty tensor holds arbitrary
        # memory content (possibly inf / nan). Real and imaginary parts are drawn
        # uniformly in [-1/sqrt(d_model), 1/sqrt(d_model)]
        bound = 1.0 / math.sqrt(d_model) if d_model > 0 else 0.0
        self.A = nn.Parameter(self._uniform_complex((f, d_model, d_s), bound))
        self.B = nn.Parameter(self._uniform_complex((f, d_s), bound))

    @staticmethod
    def _uniform_complex(shape, bound):
        """Complex tensor of the given shape with real and imaginary parts in [-bound, bound]."""
        # the trailing axis of size 2 holds the (real, imaginary) pair
        real_imag = torch.empty((*shape, 2), dtype=torch.float32).uniform_(-bound, bound)
        return torch.view_as_complex(real_imag)

    def forward(self, x):
        """
        x: [(B M) x F x D]
        out: [(B M) x F x d_s]
        """
        _, f, _ = x.shape

        assert f == self.f, "wrong dimensions of x"

        # one [D x d_s] matrix and one bias vector per frequency
        out = einops.einsum(self.A, x, 'f d d_s, b f d -> b f d_s') + self.B

        return out




In [25]:
# Shape check of the seasonal disentangler: 48 series, 7 patches, 20 -> 10 features.
# NOTE(review): `input` shadows the Python builtin of the same name.
input = torch.rand(48, 7, 20)
sfd = SeasonalFeatureDisentangler(20, 10, 7)
out = sfd(input)
log.debug(out.shape)

DEBUG:APP:torch.Size([48, 7, 10])


In [26]:
# Shape check of the per-frequency einsum used by FourierLayer, with real-valued tensors.
input = torch.rand(48, 4, 20)
A = torch.rand(4, 20, 10)
B = torch.rand(4, 10)
# one [20 x 10] matrix per frequency; B is broadcast over the batch axis
out = einops.einsum(A, input, 'f d d_s, b f d -> b f d_s') + B
log.debug(out.shape)

DEBUG:APP:torch.Size([48, 4, 10])


## CoPST Encoder

In [27]:
class CoPSTEncoder(nn.Module):
    """
    PatchTST backbone followed by the CoST trend / seasonal disentangler.

    num_channels: number of channels M
    num_var: length of the lookback window L
    patch_len: patch dimension P
    stride: step between the starts of two consecutive patches
    batch_size: batch size B
    d_model: latent dimension D of the backbone
    d_s: dimension of the seasonal representation
    d_t: dimension of the trend representation
    n_layers: number of Transformer layers of the backbone
    n_heads: attention heads of the backbone
    dropout: dropout probability
    """
    def __init__(self, num_channels, num_var, patch_len, stride, batch_size = 16, d_model = 128, d_s = 64, d_t = 64, n_layers = 3, n_heads = 16, dropout = 0.2):
        super(CoPSTEncoder, self).__init__()

        self.log = setup_log(self, LOG_LEVEL)

        # PatchTST layer (backbone encoder); the hyper-parameters are forwarded instead of
        # silently falling back to the defaults of the sub-module
        self.ptst = PatchTSTEncoder(
            num_channels, num_var, patch_len, stride,
            batch_size = batch_size, d_model = d_model,
            n_layers = n_layers, n_heads = n_heads, dropout = dropout,
        )
        self.patch_num = self.ptst.patch_num

        # CoST layer (disentangler)
        self.cost = Cost(
            self.patch_num,
            batch_size = batch_size, d_model = d_model, d_s = d_s, d_t = d_t,
            n_layers = n_layers, n_heads = n_heads, dropout = dropout,
        )

    def forward(self, x):
        """
        x: [B x L x M]
        outputs: {[(B M) x N x d_t], [(B M) x N x d_s]}
        """
        x = self.ptst(x)

        return self.cost(x)

## CoPST Model

In [49]:
# Contrastive model similar to MoCo
"""
https://arxiv.org/pdf/1911.05722.pdf
https://github.com/facebookresearch/moco/blob/main/moco/builder.py
"""

class CoPSTModel(nn.Module):
    """
    MoCo-style contrastive wrapper around a query and a key encoder.

    encoder_q: encoder updated by back-propagation; must expose `patch_num` and return
               a (trend, seasonal) pair
    encoder_k: encoder with the same structure, updated only by momentum
    comp_dimension: dimension of the trend representation fed to the projection heads
    alpha: weight of the seasonal loss
    K: size of the queue of negative keys (must be a multiple of the key batch size)
    m: momentum of the key-encoder update
    T: softmax temperature
    """
    def __init__(self, encoder_q, encoder_k, comp_dimension = 64, alpha = 0.05, K = 65536, m = 0.999, T = 0.07):
        super(CoPSTModel, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        self.alpha = alpha

        self.encoder_q = encoder_q
        self.encoder_k = encoder_k

        self.patch_num = encoder_q.patch_num

        # projection heads for queries and keys
        self.head_q = nn.Sequential(
            nn.Linear(comp_dimension, comp_dimension),
            nn.ReLU(),
            nn.Linear(comp_dimension, comp_dimension)
        )
        self.head_k = nn.Sequential(
            nn.Linear(comp_dimension, comp_dimension),
            nn.ReLU(),
            nn.Linear(comp_dimension, comp_dimension)
        )

        # initialize the parameters of the key encoder and projection head
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False # the key encoder will be updated by the momentum update

        for param_q, param_k in zip(self.head_q.parameters(), self.head_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False # the head_k will be updated by the momentum update

        # register a dictionary buffer as a queue (decoupled from the minibatch size)
        # https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer
        self.register_buffer('queue', F.normalize(torch.randn(comp_dimension, K), dim=0))
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update for key encoder and key projection head
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1 - self.m)

        for param_q, param_k in zip(self.head_q.parameters(), self.head_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1 - self.m)

    def compute_loss(self, q, k, k_negs):
        """
        InfoNCE loss with one positive key per query and the queue as negatives.

        q: [B x C], k: [B x C], k_negs: [C x K]
        output: scalar loss
        """
        # compute logits
        # positive logits: Bx1 (one timestamp as positive)
        l_pos = einops.einsum(q, k, 'b c,b c->b').unsqueeze(-1)
        # negative logits: BxK
        l_neg = einops.einsum(q, k_negs, 'b c,c k->b k')

        # logits: Bx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators - the first column of each row is the positive sample,
        # so we can consider this as a classification problem and use the CE.
        # The labels are created on the device of the logits so that the loss also works on GPU
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        loss = F.cross_entropy(logits, labels)

        return loss

    def get_polar(self, x):
        """Return the (amplitude, phase) pair of a complex tensor."""
        return (x.abs(), x.angle())

    def instance_contrastive_loss(self, z1, z2):
        """
        Instance-wise contrastive loss computed independently for every frequency: for each
        element of z1 the positive is the element of z2 with the same batch index (and
        vice versa), every other element of the two batches is a negative.

        z1, z2: [B x F x d_s]
        output: scalar loss
        """
        B = z1.shape[0]
        z = torch.cat([z1, z2], dim=0)  # 2B x F x d_s
        z = einops.rearrange(z, 'b f d_s -> f b d_s')  # F x 2B x d_s
        sim = einops.einsum(z, z, 'f b_1 d_s, f b_2 d_s -> f b_1 b_2')  # F x 2B x 2B
        # remove the diagonal (self-similarity) by merging the two triangles
        logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # F x 2B x (2B-1)
        logits += torch.triu(sim, diagonal=1)[:, :, 1:]
        log.debug(f"logits: {logits.shape}")
        logits = -F.log_softmax(logits, dim=-1)

        # index tensor on the device of the inputs so that the indexing also works on GPU
        i = torch.arange(B, device=z1.device)
        # positives of the first half (their column is shifted by one by the diagonal removal)
        pos_first = logits[:, i, B + i - 1]
        log.debug(pos_first.shape)
        loss = (pos_first.mean() + logits[:, B + i, i].mean()) / 2
        return loss

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Overwrite the oldest entries of the queue with `keys` and advance the pointer.

        keys: [B x C]
        """
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0, "K must be a multiple of batch_size"

        # replace keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T

        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr

    def forward(self, x_q, x_k):
        """
        x_q, x_k: two views of the same batch, in the format expected by the encoders
        output: scalar loss (trend loss + alpha * seasonal loss / 2)
        """
        # select a random timestamp
        rand_idx = np.random.randint(0, self.patch_num)

        # trend and seasonal queries
        q_t, q_s = self.encoder_q(x_q)

        q_t = F.normalize(self.head_q(q_t[:, rand_idx]), dim=-1)

        # compute key features
        with torch.no_grad():  # no gradient update for keys (momentum update will be used)
            self._momentum_update_key_encoder()  # update key encoder using momentum
            # only the trend keys come from the momentum encoder
            k_t, _ = self.encoder_k(x_k)
            k_t = F.normalize(self.head_k(k_t[:, rand_idx]), dim=-1)

        loss = 0

        loss += self.compute_loss(q_t, k_t, self.queue.clone().detach())
        self._dequeue_and_enqueue(k_t)

        # the seasonal keys are computed by the query encoder (with gradient)
        q_s = F.normalize(q_s, dim=-1)
        _, k_s = self.encoder_q(x_k)
        k_s = F.normalize(k_s, dim=-1)

        # the amplitude and phase losses must be computed in the frequency domain
        q_s_freq = torch.fft.rfft(q_s, dim=1)
        k_s_freq = torch.fft.rfft(k_s, dim=1)
        q_s_amp, q_s_phase = self.get_polar(q_s_freq)
        k_s_amp, k_s_phase = self.get_polar(k_s_freq)

        seasonal_loss = self.instance_contrastive_loss(q_s_amp, k_s_amp) + \
                        self.instance_contrastive_loss(q_s_phase,k_s_phase)
        loss += (self.alpha * (seasonal_loss/2))

        return loss

In [36]:
# Scratch tensors: one complex and one real [16 x 64] matrix.
# NOTE(review): neither is used by the visible cells — confirm they are still needed.
a = torch.rand((16, 64), dtype=torch.cfloat)
b = torch.rand((16, 64))

In [50]:
# End-to-end smoke test of the contrastive model on random data.
# M: channels, L: lookback window, P: patch length, S: stride
M = 3
L = 50
P = 5
S = 3
input = torch.rand((16, L, M))
encoder = CoPSTEncoder(M, L, P, S)
# the queue size must be a multiple of the key batch size (B * M = 16 * 3 = 48)
model = CoPSTModel(encoder, copy.deepcopy(encoder), K = 48 * 100)
# out_t, out_s = encoder(input)
# log.debug(f"{out_t.shape} {out_s.shape}")
# the same batch is used as both views; the cell displays the resulting loss
model(input, input)

DEBUG:APP:logits: torch.Size([9, 96, 95])
DEBUG:APP:torch.Size([9, 48])
DEBUG:APP:logits: torch.Size([9, 96, 95])
DEBUG:APP:torch.Size([9, 48])


tensor(0.5496, grad_fn=<AddBackward0>)