<a href="https://colab.research.google.com/github/prithuls/Anomaly-Transformer/blob/main/Transformers_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Vision Transformer**

In [None]:
import torch
import torch.nn as nn

In [None]:
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each patch.

    A single Conv2d whose kernel size and stride both equal ``patch_size`` acts
    as one shared linear projection applied to every patch.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each square patch.
        in_channels: number of image channels.
        embed_dim: size of each patch embedding.

    Forward maps (n_samples, in_channels, H, W) to
    (n_samples, n_patches, embed_dim).
    """

    def __init__(self, img_size, patch_size, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (n, embed_dim, H/p, W/p) -> (n, embed_dim, n_patches) -> (n, n_patches, embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)

In [None]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: token embedding size (input and output).
        n_heads: number of attention heads; each head has ``dim // n_heads`` features.
        qkv_bias: whether the joint q/k/v projection has a bias.
        attn_p: dropout applied to the attention weights.
        proj_p: dropout applied after the output projection.

    Raises:
        ValueError: in ``forward`` if the last axis of the input is not ``dim``.
    """

    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** (-0.5)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        batch, tokens, channels = x.shape
        if channels != self.dim:
            raise ValueError

        # (batch, tokens, 3*dim) -> (3, batch, heads, tokens, head_dim)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.n_heads, self.head_dim)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # (batch, heads, tokens, head_dim) -> (batch, tokens, dim)
        context = torch.matmul(weights, value).transpose(1, 2).flatten(2)
        return self.proj_drop(self.proj(context))


In [None]:
class MLP(nn.Module):
    """Two-layer feed-forward network with a GELU non-linearity.

    Args:
        in_features: size of the input features.
        hidden_features: size of the hidden layer.
        out_features: size of the output features.
        p: dropout probability, applied after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features, out_features, p=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """Run the forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(n_samples, n_tokens, in_features)``.

        Returns
        -------
        torch.Tensor
            Shape ``(n_samples, n_tokens, out_features)``.
        """
        x = self.fc1(x)   # (n_samples, n_tokens, hidden_features)
        # The activation was previously defined but never applied, which made
        # fc1 -> fc2 collapse into a single linear map.
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)   # (n_samples, n_tokens, out_features)
        x = self.drop(x)
        return x



In [None]:
class Block(nn.Module):
    """Pre-norm transformer encoder block.

    ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.

    Args:
        dim: token embedding size.
        n_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias: whether the q/k/v projection has a bias.
        p: dropout after the attention output projection.
        attn_p: dropout on the attention weights.

    NOTE: ``p`` is not forwarded to the MLP, whose dropout stays at its default.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4.0, qkv_bias=True, p=0., attn_p=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, n_heads=n_heads, qkv_bias=qkv_bias, attn_p=attn_p, proj_p=p)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        mlp_width = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_width, out_features=dim)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))

In [None]:
class VisionTransformer(nn.Module):
    """ViT image classifier.

    The pipeline is patch embedding, a learned [CLS] token, learned positional
    embeddings, ``depth`` transformer blocks, a final LayerNorm, and a linear
    head applied to the [CLS] token.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_channels: number of image channels.
        n_classes: number of output logits.
        embed_dim: token embedding size.
        depth: number of transformer blocks.
        n_heads: attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        qkv_bias: whether the q/k/v projections have a bias.
        p: dropout after adding positional embeddings and after each block's
            attention output projection.
        attn_p: dropout on the attention weights.

    Forward returns logits of shape (n_samples, n_classes).
    """

    def __init__(
        self,
        img_size=256,
        patch_size=16,
        in_channels=3,
        n_classes=1000,
        embed_dim=768,
        depth=12,
        n_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        p=0,
        attn_p=0,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim
        )
        n_positions = 1 + self.patch_embed.n_patches  # one extra slot for [CLS]
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_positions, embed_dim))
        self.pos_drop = nn.Dropout(p=p)
        self.blocks = nn.ModuleList(
            Block(dim=embed_dim, n_heads=n_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, p=p, attn_p=attn_p)
            for _ in range(depth)
        )
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        batch = x.shape[0]
        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = self.pos_drop(torch.cat((cls, tokens), dim=1) + self.pos_embed)
        for blk in self.blocks:
            tokens = blk(tokens)
        # Classify from the normalized [CLS] token only.
        return self.head(self.norm(tokens)[:, 0])

## **Swin Transformer**

In [None]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 KB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.0


In [None]:
import torch
from torch import nn, einsum
import numpy as np
from einops import rearrange, repeat

In [None]:
## For shifting window

class CyclicShift(nn.Module):
    """Cyclically roll a tensor by ``displacement`` along dims 1 and 2.

    In this notebook those are the spatial axes of the channels-last maps that
    ``WindowAttention`` works on. A negative displacement rolls the other way.
    """

    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        offset = self.displacement
        return x.roll(shifts=(offset, offset), dims=(1, 2))

class Residual(nn.Module):
    """Skip connection: return ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch = self.fn(x, **kwargs)
        return branch + x

In [None]:
class PreNorm(nn.Module):
    """Apply LayerNorm over the last ``dim`` features, then call ``fn`` on the result."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> hidden_dim), GELU, Linear(hidden_dim -> dim)."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        stages = [nn.Linear(dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, dim)]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out

In [None]:
## This helps with shifted windows

def create_mask(window_size, displacement, upper_lower, left_right):
    """Build an additive attention mask for one shifted window.

    The result is a (window_size**2, window_size**2) float tensor of zeros with
    ``-inf`` wherever a query token and a key token lie on opposite sides of
    the shift boundary.

    Args:
        window_size: side length of the square window.
        displacement: shift amount; the boundary sits ``displacement`` rows or
            columns from the end of the window.
        upper_lower: mask across the horizontal boundary (last ``displacement`` rows).
        left_right: mask across the vertical boundary (last ``displacement`` columns).
    """
    n_tokens = window_size ** 2
    neg_inf = float('-inf')
    mask = torch.zeros(n_tokens, n_tokens)

    if upper_lower:
        # Tokens are flattened row-major, so whole rows are contiguous blocks.
        cut = -displacement * window_size
        mask[cut:, :cut] = neg_inf
        mask[:cut, cut:] = neg_inf

    if left_right:
        # View as (query_row, query_col, key_row, key_col) to mask by column.
        grid = mask.view(window_size, window_size, window_size, window_size)
        grid[:, -displacement:, :, :-displacement] = neg_inf
        grid[:, :-displacement, :, -displacement:] = neg_inf
        mask = grid.reshape(n_tokens, n_tokens)

    return mask


In [None]:
## For distances between two patches / windows

def get_relative_distances(window_size):
    """Pairwise 2-D offsets between all positions of a square window.

    Positions are enumerated row-major as (row, col). The returned int64 tensor
    has shape (window_size**2, window_size**2, 2), with
    ``out[i, j] = coord[j] - coord[i]``. Each component lies in
    ``[-(window_size - 1), window_size - 1]``.
    """
    rows = torch.arange(window_size).repeat_interleave(window_size)
    cols = torch.arange(window_size).repeat(window_size)
    coords = torch.stack((rows, cols), dim=1)
    return coords.unsqueeze(0) - coords.unsqueeze(1)

In [None]:
# Scratch check: relative (row, col) offsets for a 2x2 window, row-major order.
indices = torch.tensor(np.array([[row, col] for row in range(2) for col in range(2)]))
distances = indices.unsqueeze(0) - indices.unsqueeze(1)

In [None]:
class WindowAttention(nn.Module):
    """Multi-head self-attention computed independently inside local windows.

    The input map is split into non-overlapping ``window_size x window_size``
    windows, and attention is computed only among the tokens of each window.
    With ``shifted=True`` the map is cyclically rolled by
    ``-(window_size // 2)`` before partitioning and rolled back afterwards.
    Additive ``-inf`` masks (from ``create_mask``) are then applied to the
    bottom row and rightmost column of windows. After the roll, those windows
    hold tokens brought in from the opposite edges of the map.

    Args:
        dim: channel size of the input and output (last axis of ``x``).
        heads: number of attention heads.
        head_dim: per-head feature size; q/k/v have width ``heads * head_dim``.
        shifted: use the shifted-window variant described above.
        window_size: side length of each square window.
        relative_pos_embedding: if True, learn a ``(2*window_size-1, 2*window_size-1)``
            table indexed by relative offsets. Otherwise learn a full
            ``(window_size**2, window_size**2)`` additive bias.
    """
    def __init__(
        self,
        dim,
        heads,
        head_dim,
        shifted,
        window_size,
        relative_pos_embedding
    ):
        super().__init__()
        inner_dim = head_dim * heads
        self.heads = heads
        self.scale = head_dim ** -0.5
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        if self.shifted:
            # Roll by -displacement before attention, +displacement after.
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)

            # Masks are frozen nn.Parameters (requires_grad=False): registered on
            # the module, but never updated by the optimizer.
            self.upper_lower_mask = nn.Parameter(
                create_mask(
                    window_size = window_size,
                    displacement = displacement,
                    upper_lower = True,
                    left_right = False
                ), requires_grad= False
            )

            self.left_right_mask = nn.Parameter(
                create_mask(
                    window_size = window_size,
                    displacement = displacement,
                    upper_lower = False,
                    left_right = True
                ), requires_grad= False
            )

        # Single projection producing q, k and v concatenated on the last axis.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias= False)

        if self.relative_pos_embedding:
            # Offsets lie in [-(window_size-1), window_size-1]; adding window_size-1
            # shifts them into [0, 2*window_size-2] so they can index pos_embedding.
            # NOTE(review): this is a plain tensor attribute, not a registered
            # buffer, so module.to(device) does not move it; indexing relies on
            # PyTorch accepting CPU index tensors. Confirm on GPU.
            self.relative_indices = get_relative_distances(window_size) + window_size - 1 ## Values get positive here
            self.pos_embedding = nn.Parameter(
                torch.randn(
                    2 * window_size - 1, 2 * window_size - 1
                )
            )
        else:
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(inner_dim, dim)



    def forward(self, x):
        """Attend within windows.

        Assumes x is channels-last (b, n_h, n_w, dim), with n_h and n_w
        divisible by window_size; the rearrange patterns below require that.
        Returns a tensor of the same layout.
        """
        if self.shifted:
            x = self.cyclic_shift(x)
        
        b, n_h, n_w, _, h = *x.shape, self.heads

        qkv = self.to_qkv(x).chunk(3, dim= -1)
        # Number of windows along height and width.
        nw_h = n_h // self.window_size
        nw_w = n_w // self.window_size

        # Each of q, k, v -> (b, heads, nw_h*nw_w windows, window_size**2 tokens, head_dim).
        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h = h, w_h = self.window_size, w_w = self.window_size,), qkv
        )

        # Scaled similarity per window: (b, heads, windows, tokens, tokens).
        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale
        if self.relative_pos_embedding:
            # Gather a (tokens, tokens) bias from the relative-offset table.
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding

        if self.shifted:
            # Windows are flattened row-major as (nw_h nw_w): the last nw_w
            # windows form the bottom row, and every nw_w-th window starting at
            # nw_w-1 forms the rightmost column.
            dots[:, :, -nw_w : ] += self.upper_lower_mask
            dots[:, :, nw_w - 1 :: nw_w] += self.left_right_mask

        attn = dots.softmax(dim = -1)
        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        # Undo the window partition: back to (b, n_h, n_w, heads*head_dim).
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out)

        return out

In [None]:
class SwinBlock(nn.Module):
    """Pre-norm Swin block: window attention, then an MLP, each wrapped in a residual.

    Args:
        dim: channel size of the (channels-last) tokens.
        heads: attention heads.
        head_dim: per-head feature size.
        mlp_dim: hidden width of the feed-forward network.
        shifted: use shifted-window attention.
        window_size: side length of each attention window.
        relative_pos_embedding: use the relative-offset position bias.
    """

    def __init__(
        self,
        dim,
        heads,
        head_dim,
        mlp_dim,
        shifted,
        window_size,
        relative_pos_embedding
    ):
        super().__init__()
        window_attention = WindowAttention(
            dim=dim,
            heads=heads,
            head_dim=head_dim,
            shifted=shifted,
            window_size=window_size,
            relative_pos_embedding=relative_pos_embedding,
        )
        self.attention_block = Residual(PreNorm(dim, window_attention))
        self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))

    def forward(self, x):
        return self.mlp_block(self.attention_block(x))

In [None]:
class PatchMerging(nn.Module):
    """Downsample a channels-first map by merging non-overlapping patches.

    Each ``downscaling_factor x downscaling_factor`` patch is unfolded into one
    vector of ``in_channels * downscaling_factor**2`` features and projected
    linearly to ``out_channels``.

    Input: (b, in_channels, h, w). Output is channels-last:
    (b, h // factor, w // factor, out_channels).
    """

    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        merged_features = in_channels * downscaling_factor ** 2
        self.linear = nn.Linear(merged_features, out_channels)

    def forward(self, x):
        batch, _, height, width = x.shape
        factor = self.downscaling_factor
        patches = self.patch_merge(x).view(batch, -1, height // factor, width // factor)
        return self.linear(patches.permute(0, 2, 3, 1))

In [None]:
class StageModule(nn.Module):
    """One Swin stage: patch merging, then pairs of regular and shifted Swin blocks.

    Input is channels-first (b, c, h, w), as unpacked by PatchMerging. The
    blocks run on the channels-last output of PatchMerging, and the result is
    permuted back to channels-first (b, hidden_dimension, h', w').

    Args:
        in_channels: channels of the input map.
        hidden_dimension: channels after patch merging (and of the output).
        layers: total number of Swin blocks; must be even, because blocks come
            in (regular, shifted) pairs.
        downscaling_factor: spatial downsampling of the patch-merging step.
        num_heads: attention heads per block.
        head_dim: per-head feature size.
        window_size: side length of each attention window.
        relative_pos_embedding: use the relative-offset position bias.

    Raises:
        AssertionError: if ``layers`` is odd.
    """

    def __init__(
        self,
        in_channels,
        hidden_dimension,
        layers,
        downscaling_factor,
        num_heads,
        head_dim,
        window_size,
        relative_pos_embedding
    ):
        super().__init__()
        # Message matches the other StageModule definition in this notebook.
        assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'
        self.patch_partition = PatchMerging(
            in_channels=in_channels,
            out_channels=hidden_dimension,
            downscaling_factor=downscaling_factor
        )

        def make_block(shifted):
            # All blocks in a stage share every hyper-parameter except the shift.
            return SwinBlock(
                dim=hidden_dimension,
                heads=num_heads,
                head_dim=head_dim,
                mlp_dim=hidden_dimension * 4,
                shifted=shifted,
                window_size=window_size,
                relative_pos_embedding=relative_pos_embedding
            )

        self.layers = nn.ModuleList([])
        for _ in range(layers // 2):
            self.layers.append(nn.ModuleList([make_block(False), make_block(True)]))

    def forward(self, x):
        x = self.patch_partition(x)   # (b, h', w', hidden_dimension)
        for regular_block, shifted_block in self.layers:
            x = regular_block(x)
            x = shifted_block(x)
        return x.permute(0, 3, 1, 2)

In [None]:
class SwinTransformer(nn.Module):
    """Four-stage Swin Transformer classifier.

    Stage ``i`` (0-based) produces ``hidden_dim * 2**i`` channels. The final
    feature map is average-pooled over its spatial axes and passed through a
    LayerNorm and a linear head.

    Args (keyword-only):
        hidden_dim: channel width of stage 1.
        layers: blocks per stage (4 entries, each even).
        heads: attention heads per stage (4 entries).
        channels: input image channels.
        num_classes: number of output logits.
        head_dim: per-head feature size.
        window_size: side length of each attention window.
        downscaling_factor: patch-merging factor per stage (4 entries).
        relative_pos_embedding: use the relative-offset position bias.
    """

    def __init__(
        self,
        *,
        hidden_dim,
        layers,
        heads,
        channels=3,
        num_classes=1000,
        head_dim=32,
        window_size=7,
        downscaling_factor=(4, 2, 2, 2),
        relative_pos_embedding=True
    ):
        super().__init__()

        widths = [hidden_dim * 2 ** i for i in range(4)]   # 1x, 2x, 4x, 8x
        inputs = [channels] + widths[:-1]
        for idx in range(4):
            stage = StageModule(
                in_channels=inputs[idx],
                hidden_dimension=widths[idx],
                layers=layers[idx],
                downscaling_factor=downscaling_factor[idx],
                num_heads=heads[idx],
                head_dim=head_dim,
                window_size=window_size,
                relative_pos_embedding=relative_pos_embedding,
            )
            # Attribute names stage1..stage4 are kept so state_dict keys are stable.
            setattr(self, f'stage{idx + 1}', stage)

        final_width = widths[-1]
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(final_width),
            nn.Linear(final_width, num_classes)
        )

    def forward(self, img):
        x = img
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            x = stage(x)
        pooled = x.mean(dim=[2, 3])   # global average over the spatial axes
        return self.mlp_head(pooled)

In [None]:
# Four-stage Swin model with a 3-class head, plus one dummy 224x224 RGB image.
net = SwinTransformer(
    hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24),
    channels=3, num_classes=3, head_dim=32, window_size=7,
    downscaling_factor=(4, 2, 2, 2), relative_pos_embedding=True,
)
dummy_x = torch.randn(1, 3, 224, 224)

In [None]:
# Forward the dummy image; with num_classes=3 the logits have shape (1, 3).
logits = net(dummy_x)
print(net)

In [None]:
!pip install torchinfo

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchinfo
  Downloading torchinfo-1.7.2-py3-none-any.whl (22 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.7.2


In [None]:
from torchinfo import summary

In [None]:
summary(
    net,
    input_size=(32, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params"],
)

Layer (type:depth-idx)                                       Input Shape               Output Shape              Param #
SwinTransformer                                              [32, 3, 224, 224]         [32, 3]                   --
├─StageModule: 1-1                                           [32, 3, 224, 224]         [32, 96, 56, 56]          --
│    └─PatchMerging: 2-1                                     [32, 3, 224, 224]         [32, 56, 56, 96]          --
│    │    └─Unfold: 3-1                                      [32, 3, 224, 224]         [32, 48, 3136]            --
│    │    └─Linear: 3-2                                      [32, 56, 56, 48]          [32, 56, 56, 96]          4,704
│    └─ModuleList: 2-2                                       --                        --                        --
│    │    └─ModuleList: 3-3                                  --                        --                        228,244
├─StageModule: 1-2                                         

In [None]:
from torch.nn.modules.module import _addindent
import torch
import numpy as np
def torch_summarize(model, show_weights=True, show_parameters=True):
    """Summarize a torch model by listing each direct child with its weights and parameter count.

    Children that are ``nn.Sequential`` (or legacy ``nn.Container``) instances
    are expanded recursively with the same flags. Every other child is shown
    with its own ``repr``.

    Args:
        model: the ``nn.Module`` to summarize.
        show_weights: append ``weights=<tuple of parameter shapes>`` to each child.
        show_parameters: append ``parameters=<total element count>`` to each child.

    Returns:
        A multi-line string in the style of ``nn.Module.__repr__``.
    """
    # nn.Container is a long-deprecated alias; look it up defensively so this
    # keeps working on PyTorch builds that no longer ship it.
    legacy_container = getattr(torch.nn.modules.container, 'Container', None)
    container_types = tuple(
        t for t in (legacy_container, torch.nn.Sequential) if t is not None
    )

    parts = [model.__class__.__name__ + ' (\n']
    for key, module in model._modules.items():
        if module is None:
            # Optional submodules may be registered as None; nothing to count.
            modstr, param_list = 'None', []
        elif isinstance(module, container_types):
            # Forward the caller's flags; they used to be dropped on recursion.
            modstr = torch_summarize(
                module, show_weights=show_weights, show_parameters=show_parameters
            )
            param_list = list(module.parameters())
        else:
            modstr = repr(module)
            param_list = list(module.parameters())
        modstr = _addindent(modstr, 2)

        params = sum(p.numel() for p in param_list)
        weights = tuple(tuple(p.size()) for p in param_list)

        entry = '  (' + key + '): ' + modstr
        if show_weights:
            entry += ', weights={}'.format(weights)
        if show_parameters:
            entry += ', parameters={}'.format(params)
        parts.append(entry + '\n')

    parts.append(')')
    return ''.join(parts)

In [None]:
print(torch_summarize(net, show_weights=True, show_parameters=True))

In [None]:
def get_output_shape(model, image_dim):
    """Return the output shape of ``model`` for a random input of shape ``image_dim``.

    Only the shape is needed, so the forward pass runs under ``torch.no_grad()``
    to avoid building an autograd graph for a throwaway input. The deprecated
    ``.data`` access is no longer needed.

    Args:
        model: callable module to probe.
        image_dim: full input shape, including the batch dimension.

    Returns:
        ``torch.Size`` of the model output.
    """
    with torch.no_grad():
        return model(torch.rand(*image_dim)).shape

In [None]:
get_output_shape(model=net, image_dim=(32, 3, 224, 224))

torch.Size([32, 3])

In [None]:
class StageModule(nn.Module):
    """One Swin stage: patch merging, then pairs of regular and shifted Swin blocks.

    Takes a channels-first map and returns a channels-first map with
    ``hidden_dimension`` channels. ``layers`` must be even.
    """

    def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()
        assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'

        self.patch_partition = PatchMerging(
            in_channels=in_channels,
            out_channels=hidden_dimension,
            downscaling_factor=downscaling_factor,
        )

        def make_block(shifted):
            return SwinBlock(
                dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                shifted=shifted, window_size=window_size, relative_pos_embedding=relative_pos_embedding,
            )

        self.layers = nn.ModuleList(
            nn.ModuleList([make_block(False), make_block(True)]) for _ in range(layers // 2)
        )

    def forward(self, x):
        x = self.patch_partition(x)
        for pair in self.layers:
            for block in pair:   # regular block first, then its shifted partner
                x = block(x)
        return x.permute(0, 3, 1, 2)

In [None]:
# Build a single stage in isolation with explicit hyper-parameters; these
# mirror stage 1 of the SwinTransformer configured earlier in the notebook.
# (Bare names such as `in_channels` are not defined at module level.)
stageM = StageModule(
    in_channels=3,
    hidden_dimension=96,
    layers=2,
    downscaling_factor=4,
    num_heads=3,
    head_dim=32,
    window_size=7,
    relative_pos_embedding=True,
)

In [None]:
class PatchMerging(nn.Module):
    """Downsample a channels-first map by merging non-overlapping patches.

    Each ``downscaling_factor x downscaling_factor`` patch is unfolded into one
    vector of ``in_channels * downscaling_factor**2`` features and projected
    linearly to ``out_channels``.

    Input: (b, in_channels, h, w). Output is channels-last:
    (b, h // factor, w // factor, out_channels).
    """

    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        merged_features = in_channels * downscaling_factor ** 2
        self.linear = nn.Linear(merged_features, out_channels)

    def forward(self, x):
        batch, _, height, width = x.shape
        factor = self.downscaling_factor
        patches = self.patch_merge(x).view(batch, -1, height // factor, width // factor)
        return self.linear(patches.permute(0, 2, 3, 1))

In [None]:
# Stand-alone PatchMerging: 3 input channels, 4x spatial downscaling, 32 output features.
net = PatchMerging(
    in_channels=3,
    out_channels=32,
    downscaling_factor=4,
)

In [None]:
summary(
    net,
    input_size=(32, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params"],
)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
PatchMerging                             [32, 3, 224, 224]         [32, 56, 56, 32]          --
├─Unfold: 1-1                            [32, 3, 224, 224]         [32, 48, 3136]            --
├─Linear: 1-2                            [32, 56, 56, 48]          [32, 56, 56, 32]          1,568
Total params: 1,568
Trainable params: 1,568
Non-trainable params: 0
Total mult-adds (M): 0.05
Input size (MB): 19.27
Forward/backward pass size (MB): 25.69
Params size (MB): 0.01
Estimated Total Size (MB): 44.96

## GTN

# **Anomaly Transformer**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math, os, pickle, collections, argparse, time
import pandas as pd
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader
from torch.backends import cudnn

### Anomaly Attention

In [None]:
class TriangularCausalMask:
    """Boolean causal mask of shape (B, 1, L, L).

    ``mask[b, 0, i, j]`` is True when ``j > i``. Passing it to ``masked_fill_``
    therefore hides future positions from each query.

    Args:
        B: batch size.
        L: sequence length.
        device: device to place the mask on.
    """

    def __init__(self, B, L, device='cpu'):
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            self._mask = torch.triu(
                torch.ones(mask_shape, dtype=torch.bool), diagonal=1
            ).to(device)

    @property
    def mask(self):
        # Return the stored tensor. Returning `self.mask` recursed forever.
        return self._mask


class AnomalyAttention(nn.Module):
    """Anomaly-Attention.

    Computes scaled dot-product attention (the "series" association) together
    with a Gaussian "prior" association over the absolute distance |i - j|
    between positions. The width of the Gaussian is derived from a learned,
    per-position ``sigma``.

    Args:
        win_size: sequence length the |i - j| table is built for. ``forward``
            expects query and key lengths equal to it, because the prior is
            broadcast from a ``win_size x win_size`` table.
        mask_flag: apply a causal mask to the attention scores.
        scale: score scaling; defaults to ``1 / sqrt(E)``.
        attention_dropout: dropout on the series attention weights.
        output_attention: also return ``series``, ``prior`` and ``sigma``.
    """

    def __init__(
        self,
        win_size,
        mask_flag=True,
        scale=None,
        attention_dropout=0.0,
        output_attention=False
    ):
        super(AnomalyAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        # |i - j| for every pair of positions, built vectorized. It is a
        # non-persistent buffer: it follows .to()/.cuda() and stays out of
        # state_dict. It used to be hard-wired to .cuda(), which crashed on CPU.
        positions = torch.arange(win_size, dtype=torch.float32)
        self.register_buffer(
            'distances', (positions[:, None] - positions[None, :]).abs(), persistent=False
        )

    def forward(
        self,
        queries,
        keys,
        values,
        sigma,
        attn_mask
    ):
        """Compute series/prior associations.

        Args:
            queries: (B, L, H, E).
            keys: (B, S, H, E).
            values: (B, S, H, D). D may differ from E, and S from L in
                general, though the prior requires L == S == win_size.
            sigma: (B, L, H) raw scale logits.
            attn_mask: object with a boolean ``.mask``, or None to build a
                causal mask when ``mask_flag`` is set.

        Returns:
            ``(V, series, prior, sigma)`` if ``output_attention`` else
            ``(V, None)``, where V is (B, L, H, D).
        """
        B, L, H, E = queries.shape
        scale = self.scale or 1. / math.sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)
        attn = scale * scores

        sigma = sigma.transpose(1, 2)  # B L H -> B H L
        window_size = attn.shape[-1]
        sigma = torch.sigmoid(sigma * 5) + 1e-5
        sigma = torch.pow(3, sigma) - 1
        sigma = sigma.unsqueeze(-1).repeat(1, 1, 1, window_size)  # B H L S
        prior = self.distances.to(sigma.device).unsqueeze(0).unsqueeze(0).repeat(
            sigma.shape[0], sigma.shape[1], 1, 1
        )
        # Gaussian kernel N(|i - j|; 0, sigma).
        prior = 1.0 / (math.sqrt(2 * math.pi) * sigma) * torch.exp(-prior ** 2 / (2 * sigma ** 2))

        # Normalize over the key axis (last dim) so each query's weights sum to 1.
        # This was dim=1, which normalized across heads instead.
        series = self.dropout(torch.softmax(attn, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", series, values)
        if self.output_attention:
            return (V.contiguous(), series, prior, sigma)
        else:
            return (V.contiguous(), None)


In [None]:
class AttentionLayer(nn.Module):
    """Project inputs into multi-head q/k/v/sigma, run an inner attention, and project back.

    Args:
        attention: inner attention module called as
            ``attention(queries, keys, values, sigma, attn_mask)``. It returns
            either ``(out, series, prior, sigma)`` or ``(out, None)``.
        d_model: model width.
        n_heads: number of heads.
        d_keys: per-head query/key width (default ``d_model // n_heads``).
        d_values: per-head value width (default ``d_model // n_heads``).
    """

    def __init__(
        self,
        attention,
        d_model,
        n_heads,
        d_keys=None,
        d_values=None
    ):
        super(AttentionLayer, self).__init__()
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)
        # NOTE(review): self.norm is never used in forward; kept so existing
        # checkpoints still load.
        self.norm = nn.LayerNorm(d_model)
        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.sigma_projection = nn.Linear(d_model, n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(
        self,
        queries,
        keys,
        values,
        attn_mask
    ):
        """Return ``(output, series, prior, sigma)``.

        The last three are None when the inner attention does not expose its
        attention maps.
        """
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads
        x = queries
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)
        sigma = self.sigma_projection(x).view(B, L, H)

        result = self.inner_attention(
            queries,
            keys,
            values,
            sigma,
            attn_mask
        )
        out = result[0]
        # With output_attention=False the inner attention returns only
        # (out, None); unpacking four values from it used to crash here.
        if len(result) == 4:
            series, prior, sigma = result[1:]
        else:
            series = prior = sigma = None

        out = out.view(B, L, -1)
        return self.out_projection(out), series, prior, sigma

In [None]:
class EncoderLayer(nn.Module):
    """Post-norm encoder layer: attention + residual + LayerNorm, then a
    position-wise feed-forward (two kernel-size-1 Conv1d) + residual + LayerNorm.

    Args:
        attention: module called as ``attention(x, x, x, attn_mask=...)`` and
            returning four values ``(out, attn, mask, sigma)``.
        d_model: model width; the convs treat it as channels, so ``x`` is
            presumably (B, L, d_model).
        d_ff: feed-forward hidden width (default ``4 * d_model``).
        dropout: dropout probability.
        activation: ``'relu'``; any other value selects GELU.
    """

    def __init__(
        self,
        attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation='relu'
    ):
        super(EncoderLayer, self).__init__()
        if not d_ff:
            d_ff = 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)
        if activation == 'relu':
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, attn_mask=None):
        attended, attn, mask, sigma = self.attention(x, x, x, attn_mask=attn_mask)
        residual = self.norm1(x + self.dropout(attended))
        # Conv1d wants channels on axis 1, so swap the feature and length axes around the convs.
        hidden = self.dropout(self.activation(self.conv1(residual.transpose(-1, 1))))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))
        return self.norm2(residual + hidden), attn, mask, sigma


class Encoder(nn.Module):
    """Stack of encoder layers, with an optional final normalization.

    Each layer is called as ``layer(x, attn_mask=...)`` and must return
    ``(x, series, prior, sigma)``. The per-layer extras are collected into lists.

    Returns:
        ``(x, series_list, prior_list, sigma_list)``.
    """

    def __init__(
        self,
        attn_layers,
        norm_layer=None
    ):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        records = []
        for layer in self.attn_layers:
            x, series, prior, sigma = layer(x, attn_mask=attn_mask)
            records.append((series, prior, sigma))

        series_list = [rec[0] for rec in records]
        prior_list = [rec[1] for rec in records]
        sigma_list = [rec[2] for rec in records]

        if self.norm is not None:
            x = self.norm(x)

        return x, series_list, prior_list, sigma_list

class AnomalyTransformer(nn.Module):
    """Anomaly Transformer.

    The model is a data embedding, ``e_layers`` encoder layers built on
    AnomalyAttention (no causal mask), a final LayerNorm, and a linear
    projection to ``c_out`` features.

    Args:
        win_size: window (sequence) length used by every AnomalyAttention.
        enc_in: input features per time step.
        c_out: output features per time step.
        d_model: model width.
        n_heads: attention heads.
        e_layers: number of encoder layers.
        d_ff: feed-forward width.
        dropout: dropout probability.
        activation: feed-forward activation (``'relu'``, otherwise GELU).
        output_attention: also return per-layer series/prior/sigma lists.
    """

    def __init__(
        self,
        win_size,
        enc_in,
        c_out,
        d_model=512,
        n_heads=8,
        e_layers=3,
        d_ff=512,
        dropout=0.0,
        activation='gelu',
        output_attention=True
    ):
        super(AnomalyTransformer, self).__init__()
        self.output_attention = output_attention

        self.embedding = DataEmbedding(enc_in, d_model, dropout)

        encoder_layers = []
        for _ in range(e_layers):
            anomaly_attention = AnomalyAttention(
                win_size=win_size,
                mask_flag=False,
                scale=None,
                attention_dropout=dropout,
                output_attention=output_attention,
            )
            attention_layer = AttentionLayer(anomaly_attention, d_model=d_model, n_heads=n_heads)
            encoder_layers.append(
                EncoderLayer(attention_layer, d_model=d_model, d_ff=d_ff, dropout=dropout, activation=activation)
            )
        self.encoder = Encoder(encoder_layers, norm_layer=torch.nn.LayerNorm(d_model))

        self.projection = nn.Linear(d_model, c_out, bias=True)

    def forward(self, x):
        hidden, series, prior, sigmas = self.encoder(self.embedding(x))
        reconstruction = self.projection(hidden)
        if not self.output_attention:
            return reconstruction
        return reconstruction, series, prior, sigmas

### Embedding

In [None]:
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional embedding.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / d_model)``. The table is a
    (1, max_len, d_model) buffer, so it moves with the module and is not trained.

    Args:
        d_model: embedding size (odd values are supported).
        max_len: longest sequence the table covers.
    """

    def __init__(
        self,
        d_model,
        max_len=5000
    ):
        super(PositionalEmbedding, self).__init__()
        pe = torch.zeros(max_len, d_model).float()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(
            0, d_model, 2
        ).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd columns number d_model // 2, which is one fewer than the even
        # columns when d_model is odd; trim the frequencies to match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the first ``x.size(1)`` positions, shape (1, L, d_model)."""
        return self.pe[:, :x.size(1)]

In [None]:
class TokenEmbedding(nn.Module):
    """Embed each time step's ``c_in`` features into ``d_model`` channels.

    A bias-free, kernel-3 ``Conv1d`` with circular padding slides along the
    time axis.  Input is ``(batch, seq_len, c_in)`` (it is permuted to
    channels-first for the convolution) and output is ``(batch, seq_len, d_model)``.
    """

    def __init__(
        self,
        c_in,
        d_model
    ):
        super(TokenEmbedding, self).__init__()
        # Compare the version numerically.  As plain strings "1.10.0" < "1.5.0".
        # Only major/minor are parsed, so local suffixes like "+cu118" are ignored.
        major, minor = (int(part) for part in torch.__version__.split('.')[:2])
        padding = 1 if (major, minor) >= (1, 5) else 2
        self.tokenConv = nn.Conv1d(
            in_channels= c_in,
            out_channels= d_model,
            kernel_size= 3,
            padding= padding,
            padding_mode= 'circular',
            bias= False
        )

        # tokenConv is the only Conv1d in this module, so initialise it directly.
        nn.init.kaiming_normal_(
            self.tokenConv.weight,
            mode= 'fan_in',
            nonlinearity= 'leaky_relu'
        )

    def forward(self, x):
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x


In [None]:
class DataEmbedding(nn.Module):
    """Token (value) embedding plus positional embedding, followed by dropout.

    ``TokenEmbedding`` maps ``c_in`` features to ``d_model`` channels and
    ``PositionalEmbedding`` supplies a ``d_model``-wide positional table; the two
    are summed element-wise before dropout.
    """

    def __init__(self, c_in, d_model, dropout= 0.0):
        super(DataEmbedding, self).__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        values = self.value_embedding(x)
        positions = self.position_embedding(x)
        return self.dropout(values + positions)

### Dataset Loader

In [None]:
class PSMSegLoader(object):
    """Sliding-window dataset over the PSM CSV files.

    Reads ``train.csv``, ``test.csv`` and ``test_label.csv`` from ``data_path``
    (dropping the first column of each), replaces NaNs with numbers, and
    standardises both feature splits with a scaler fitted on the training split.

    ``mode`` selects what is served:
      * ``'train'`` / ``'val'`` -- windows of the train / val split with stride
        ``step``; the label is always ``test_labels[0:win_size]`` regardless of
        the index (presumably a shape-compatible placeholder).
      * ``'test'`` -- windows of the test split and their aligned labels, stride ``step``.
      * anything else (e.g. ``'thre'``) -- non-overlapping test windows of
        length ``win_size`` with aligned labels.
    Note that ``val`` is the same array as ``test``.
    """

    def __init__(self, data_path, win_size, step, mode= 'train'):
        self.mode = mode
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()

        # Scaler statistics come from the training split only.
        train_values = np.nan_to_num(self._csv_values(data_path, 'train.csv'))
        self.scaler.fit(train_values)
        train_values = self.scaler.transform(train_values)

        test_values = np.nan_to_num(self._csv_values(data_path, 'test.csv'))
        self.test = self.scaler.transform(test_values)
        self.train = train_values
        self.val = self.test

        self.test_labels = self._csv_values(data_path, 'test_label.csv')

        print("test: ", self.test.shape)
        print("train: ", self.train.shape)

    @staticmethod
    def _csv_values(data_path, file_name):
        """Load ``data_path/file_name`` as a numpy array without its first column."""
        frame = pd.read_csv(os.path.join(data_path, file_name))
        return frame.values[:, 1:]

    def _window_source(self):
        """Array that windows are cut from for the current mode."""
        if self.mode == 'train':
            return self.train
        if self.mode == 'val':
            return self.val
        return self.test

    def __len__(self):
        is_strided = self.mode in ('train', 'val', 'test')
        stride = self.step if is_strided else self.win_size
        return (self._window_source().shape[0] - self.win_size) // stride + 1

    def __getitem__(self, index):
        start = index * self.step
        stop = start + self.win_size

        if self.mode in ('train', 'val'):
            features = self._window_source()[start:stop]
            return np.float32(features), np.float32(self.test_labels[0:self.win_size])

        if self.mode == 'test':
            return np.float32(self.test[start:stop]), np.float32(self.test_labels[start:stop])

        # Non-overlapping windows: window number (start // step) times win_size.
        block_start = start // self.step * self.win_size
        block_stop = block_start + self.win_size
        return (np.float32(self.test[block_start:block_stop]),
                np.float32(self.test_labels[block_start:block_stop]))

### Utilities

In [None]:
def my_kl_loss(p, q):
    """Smoothed KL(p || q): sum over the last axis, then mean over axis 1.

    ``1e-4`` is added inside both logs so zero probabilities do not produce
    ``-inf``.  The result has the input's shape with the last axis and axis 1
    removed.
    """
    eps = 0.0001
    log_ratio = torch.log(p + eps) - torch.log(q + eps)
    per_row = (p * log_ratio).sum(dim=-1)
    return per_row.mean(dim=1)


def adjust_learning_rate(optimizer, epoch, lr_):
    """Set every param group's lr to ``lr_ * 0.5 ** (epoch - 1)`` and log it.

    Epoch 1 uses the base rate; each later epoch halves it once more.
    """
    new_lr = lr_ * (0.5 ** ((epoch - 1) // 1))
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    print('Updating learning rate to {}'.format(new_lr))


class EarlyStopping:
    """Checkpoint on improvement and flag early stopping after ``patience`` bad calls.

    Each call compares the negated losses against the best seen so far.  If
    either ``-val_loss`` or ``-val_loss2`` is below its best + ``delta``, the
    counter increments and ``early_stop`` becomes True once it reaches
    ``patience``.  Otherwise both bests are updated, the model's ``state_dict``
    is saved to ``<path>/<dataset_name>_checkpoint.pth`` and the counter resets.
    """

    def __init__(self, patience=7, verbose=False, dataset_name='', delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.best_score2 = None
        self.early_stop = False
        # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.val_loss2_min = np.inf
        self.delta = delta
        self.dataset = dataset_name

    def __call__(self, val_loss, val_loss2, model, path):
        score = -val_loss
        score2 = -val_loss2
        if self.best_score is None:
            self.best_score = score
            self.best_score2 = score2
            self.save_checkpoint(val_loss, val_loss2, model, path)
        elif score < self.best_score + self.delta or score2 < self.best_score2 + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_score2 = score2
            self.save_checkpoint(val_loss, val_loss2, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, val_loss2, model, path):
        """Save ``model.state_dict()`` under ``path`` and record the losses as the new minima."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), os.path.join(path, str(self.dataset) + '_checkpoint.pth'))
        self.val_loss_min = val_loss
        self.val_loss2_min = val_loss2


class Solver(object):
    """Train and evaluate an AnomalyTransformer with the minimax association-discrepancy objective.

    ``config`` is a dict merged over ``Solver.DEFAULTS`` and copied onto the
    instance as attributes.  Keys read here: ``data_path``, ``batch_size``,
    ``win_size``, ``dataset``, ``input_c``, ``output_c``, ``lr``, ``k``,
    ``num_epochs``, ``model_save_path`` and ``anormly_ratio``.
    """
    DEFAULTS = {}

    def __init__(self, config):

        self.__dict__.update(Solver.DEFAULTS, **config)

        self.train_loader = get_loader_segment(self.data_path, batch_size=self.batch_size, win_size=self.win_size,
                                               mode='train',
                                               dataset=self.dataset)
        self.vali_loader = get_loader_segment(self.data_path, batch_size=self.batch_size, win_size=self.win_size,
                                              mode='val',
                                              dataset=self.dataset)
        self.test_loader = get_loader_segment(self.data_path, batch_size=self.batch_size, win_size=self.win_size,
                                              mode='test',
                                              dataset=self.dataset)
        self.thre_loader = get_loader_segment(self.data_path, batch_size=self.batch_size, win_size=self.win_size,
                                              mode='thre',
                                              dataset=self.dataset)

        self.build_model()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.criterion = nn.MSELoss()

    def build_model(self):
        """Create a 3-layer AnomalyTransformer plus Adam optimizer; move the model to GPU if available."""
        self.model = AnomalyTransformer(win_size=self.win_size, enc_in=self.input_c, c_out=self.output_c, e_layers=3)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        if torch.cuda.is_available():
            self.model.cuda()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalized_prior(prior_u):
        """Rescale a prior association map so it sums to 1 along the last axis."""
        # keepdim=True broadcasting gives the same values as the previous
        # unsqueeze(...).repeat(1, 1, 1, win_size) without allocating a full copy.
        return prior_u / torch.sum(prior_u, dim=-1, keepdim=True)

    def _discrepancy_losses(self, series, prior):
        """Layer-averaged symmetric KL terms of the minimax objective.

        Returns ``(series_loss, prior_loss)``.  ``series_loss`` compares each
        series association with the *detached* normalised prior, so its
        gradient reaches the series branch only.  ``prior_loss`` compares the
        normalised prior with the *detached* series, so its gradient reaches
        the prior branch only.
        """
        series_loss = 0.0
        prior_loss = 0.0
        for series_u, prior_u in zip(series, prior):
            prior_n = self._normalized_prior(prior_u)
            series_loss += (torch.mean(my_kl_loss(series_u, prior_n.detach()))
                            + torch.mean(my_kl_loss(prior_n.detach(), series_u)))
            prior_loss += (torch.mean(my_kl_loss(prior_n, series_u.detach()))
                           + torch.mean(my_kl_loss(series_u.detach(), prior_n)))
        return series_loss / len(prior), prior_loss / len(prior)

    def _anomaly_energy(self, x, criterion, temperature):
        """Anomaly score for one batch: softmax(-discrepancy) * per-step reconstruction error.

        Returned as a numpy array on the CPU.
        """
        output, series, prior, _ = self.model(x)
        loss = torch.mean(criterion(x, output), dim=-1)

        series_loss = 0.0
        prior_loss = 0.0
        for series_u, prior_u in zip(series, prior):
            prior_n = self._normalized_prior(prior_u)
            series_loss = series_loss + my_kl_loss(series_u, prior_n.detach()) * temperature
            prior_loss = prior_loss + my_kl_loss(prior_n, series_u.detach()) * temperature

        metric = torch.softmax((-series_loss - prior_loss), dim=-1)
        return (metric * loss).detach().cpu().numpy()

    def _collect_energy(self, loader, criterion, temperature, with_labels=False):
        """Score every batch of ``loader``.

        Returns ``(energies, labels)``.  Both are flattened 1-D arrays, and
        ``labels`` is None unless ``with_labels`` is set.
        """
        energies = []
        labels_all = []
        for input_data, labels in loader:
            x = input_data.float().to(self.device)
            energies.append(self._anomaly_energy(x, criterion, temperature))
            if with_labels:
                labels_all.append(np.asarray(labels))
        energy = np.concatenate(energies, axis=0).reshape(-1)
        if not with_labels:
            return energy, None
        return energy, np.concatenate(labels_all, axis=0).reshape(-1)

    @staticmethod
    def _point_adjust(gt, pred):
        """Point adjustment: once any point of a true anomaly segment is detected, mark the whole segment.

        Modifies ``pred`` in place and returns it.
        """
        anomaly_state = False
        for i in range(len(gt)):
            if gt[i] == 1 and pred[i] == 1 and not anomaly_state:
                anomaly_state = True
                # Walk back to the segment start, including index 0
                # (the previous range(i, 0, -1) never reached it).
                for j in range(i, -1, -1):
                    if gt[j] == 0:
                        break
                    pred[j] = 1
                for j in range(i, len(gt)):
                    if gt[j] == 0:
                        break
                    pred[j] = 1
            elif gt[i] == 0:
                anomaly_state = False
            if anomaly_state:
                pred[i] = 1
        return pred

    # ------------------------------------------------------------------ phases

    def vali(self, vali_loader):
        """Average (loss1, loss2) of the minimax objective over ``vali_loader``."""
        self.model.eval()

        loss_1 = []
        loss_2 = []
        # No gradients are needed here.  Without no_grad, every batch keeps the
        # autograd graph of its (batch, heads, win, win) association maps alive,
        # which is enough to exhaust GPU memory at batch_size=1024.
        with torch.no_grad():
            for input_data, _ in vali_loader:
                x = input_data.float().to(self.device)
                output, series, prior, _ = self.model(x)
                series_loss, prior_loss = self._discrepancy_losses(series, prior)

                rec_loss = self.criterion(output, x)
                loss_1.append((rec_loss - self.k * series_loss).item())
                loss_2.append((rec_loss + self.k * prior_loss).item())

        return np.average(loss_1), np.average(loss_2)

    def train(self):
        """Run the minimax training loop with early stopping and per-epoch LR halving."""

        print("======================TRAIN MODE======================")

        time_now = time.time()
        path = self.model_save_path
        os.makedirs(path, exist_ok=True)
        early_stopping = EarlyStopping(patience=3, verbose=True, dataset_name=self.dataset)
        train_steps = len(self.train_loader)

        for epoch in range(self.num_epochs):
            iter_count = 0
            loss1_list = []

            epoch_time = time.time()
            self.model.train()
            for i, (input_data, _labels) in enumerate(self.train_loader):

                self.optimizer.zero_grad()
                iter_count += 1
                x = input_data.float().to(self.device)

                output, series, prior, _ = self.model(x)

                # calculate Association discrepancy
                series_loss, prior_loss = self._discrepancy_losses(series, prior)

                rec_loss = self.criterion(output, x)

                loss1 = rec_loss - self.k * series_loss
                loss2 = rec_loss + self.k * prior_loss
                loss1_list.append(loss1.item())

                if (i + 1) % 100 == 0:
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.num_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                # Minimax strategy
                loss1.backward(retain_graph=True)
                loss2.backward()
                self.optimizer.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(loss1_list)

            vali_loss1, vali_loss2 = self.vali(self.test_loader)

            print(
                "Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} ".format(
                    epoch + 1, train_steps, train_loss, vali_loss1))
            early_stopping(vali_loss1, vali_loss2, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break
            adjust_learning_rate(self.optimizer, epoch + 1, self.lr)

    def test(self):
        """Threshold anomaly scores and report point-adjusted metrics.

        Returns ``(accuracy, precision, recall, f_score)``.
        """
        checkpoint_path = os.path.join(str(self.model_save_path), str(self.dataset) + '_checkpoint.pth')
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
        self.model.eval()
        temperature = 50

        print("======================TEST MODE======================")

        # reduction='none' is the non-deprecated spelling of reduce=False.
        criterion = nn.MSELoss(reduction='none')

        with torch.no_grad():
            # (1) statistics on the train set
            train_energy, _ = self._collect_energy(self.train_loader, criterion, temperature)
            # (2) + (3): threshold and evaluation both need the thre_loader
            # scores.  The model is in eval mode, so one pass serves both.
            test_energy, test_labels = self._collect_energy(
                self.thre_loader, criterion, temperature, with_labels=True)

        combined_energy = np.concatenate([train_energy, test_energy], axis=0)
        thresh = np.percentile(combined_energy, 100 - self.anormly_ratio)
        print("Threshold :", thresh)

        pred = (test_energy > thresh).astype(int)
        gt = test_labels.astype(int)

        print("pred:   ", pred.shape)
        print("gt:     ", gt.shape)

        # detection adjustment
        pred = self._point_adjust(gt, pred)

        print("pred: ", pred.shape)
        print("gt:   ", gt.shape)

        from sklearn.metrics import precision_recall_fscore_support
        from sklearn.metrics import accuracy_score
        accuracy = accuracy_score(gt, pred)
        precision, recall, f_score, support = precision_recall_fscore_support(gt, pred,
                                                                              average='binary')
        print(
            "Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f} ".format(
                accuracy, precision,
                recall, f_score))

        return accuracy, precision, recall, f_score

In [None]:
def get_loader_segment(data_path, batch_size, win_size=100, step=100, mode='train', dataset='KDD'):
    """Build a DataLoader of sliding windows for one of the supported datasets.

    ``step`` is passed through only for SMD; MSL, SMAP and PSM always use
    stride 1.  Only ``mode='train'`` shuffles; every other mode keeps order.

    Raises:
        ValueError: if ``dataset`` is not 'SMD', 'MSL', 'SMAP' or 'PSM'
            (note that the default 'KDD' is not supported).
    """
    # An if/elif chain (not a dict of classes) so only the selected loader
    # class has to be defined.
    if dataset == 'SMD':
        segments = SMDSegLoader(data_path, win_size, step, mode)
    elif dataset == 'MSL':
        segments = MSLSegLoader(data_path, win_size, 1, mode)
    elif dataset == 'SMAP':
        segments = SMAPSegLoader(data_path, win_size, 1, mode)
    elif dataset == 'PSM':
        segments = PSMSegLoader(data_path, win_size, 1, mode)
    else:
        # Without this check the name string itself reached DataLoader, which
        # then silently yielded its characters as "samples".
        raise ValueError(
            "Unsupported dataset {!r}; expected one of 'SMD', 'MSL', 'SMAP', 'PSM'".format(dataset)
        )

    return DataLoader(dataset=segments,
                      batch_size=batch_size,
                      shuffle=(mode == 'train'),
                      num_workers=0)

### Main

In [None]:
def str2bool(v):
    """Parse a boolean command-line value: True only for 'true' (case-insensitive).

    Booleans are passed through unchanged.
    """
    if isinstance(v, bool):
        return v
    # Exact comparison: `in ('true')` tested substrings of 'true', because
    # ('true') is a string, not a tuple, so '', 't' and 'rue' were truthy.
    return v.lower() == 'true'


def main(config):
    """Run the Solver in ``config.mode`` ('train' or 'test') and return it.

    ``config`` is the parsed argparse namespace; it is converted to a dict for
    ``Solver``.  Any other mode just constructs the Solver.
    """
    cudnn.benchmark = True
    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(config.model_save_path, exist_ok=True)
    solver = Solver(vars(config))

    if config.mode == 'train':
        solver.train()
    elif config.mode == 'test':
        solver.test()

    return solver


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--k', type=int, default=3)
    parser.add_argument('--win_size', type=int, default=100)
    parser.add_argument('--input_c', type=int, default=25)
    parser.add_argument('--output_c', type=int, default=25)
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--pretrained_model', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='PSM')
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
    parser.add_argument('--data_path', type=str, default='/content/')
    parser.add_argument('--model_save_path', type=str, default='checkpoints')
    parser.add_argument('--anormly_ratio', type=float, default=4.00)

    # parse_known_args ignores unrecognised argv entries instead of erroring
    # (presumably so the cell also runs inside a notebook kernel -- confirm).
    config, unknown = parser.parse_known_args()

    args = vars(config)
    option_lines = [f'{name!s}: {value!s}' for name, value in sorted(args.items())]
    print('\n'.join(['------------ Options -------------']
                    + option_lines
                    + ['-------------- End ----------------']))
    main(config)

------------ Options -------------
anormly_ratio: 4.0
batch_size: 1024
data_path: /content/
dataset: PSM
input_c: 25
k: 3
lr: 0.0001
mode: train
model_save_path: checkpoints
num_epochs: 10
output_c: 25
pretrained_model: None
win_size: 100
-------------- End ----------------
test:  (87841, 25)
train:  (132481, 25)
test:  (87841, 25)
train:  (132481, 25)
test:  (87841, 25)
train:  (132481, 25)
test:  (87841, 25)
train:  (132481, 25)
	speed: 0.6193s/iter; left time: 743.7951s
Epoch: 1 cost time: 79.06656241416931


OutOfMemoryError: ignored