## Loading Dataset as an input x — Defining Dataset Class
**The paths to the files** are as follows:
- 'dataset/y_test.csv'
- 'dataset/y_train.csv'
- 'dataset/Automotive_Ethernet_with_Attack_original_10_17_20_04_test.pcap'
- 'dataset/Automotive_Ethernet_with_Attack_original_10_17_19_50_training.pcap'

In [1]:
!pip install scapy==2.4.4



In [2]:
from IPython.display import display
from pathlib import Path

import numpy as np
import pandas as pd
from scapy.utils import RawPcapReader
from tqdm import tqdm
from scipy.stats import skew

class TimeseriesGenerator:
    pass # Implemented below

class Dataset:
    def __init__(self, df: pd.DataFrame, trim_etc_protocols=True):
        if trim_etc_protocols:
            self.df = df[df['ProtocolType'] != ''].copy()
        else:
            self.df = df
        assert self.df['abstime'].is_monotonic_increasing
        assert self.df['monotime'].is_monotonic_increasing

    @classmethod
    def _load_towids_dataset(cls, path_pcap, usec_unit, path_csv=None, **kwargs):
        # assert scapy.__version__ == '2.4.4', 'scapy version mismatch.'

        reader = RawPcapReader(str(path_pcap))
        list_output = list()
        for idx, (payload, metadata) in tqdm(enumerate(reader), desc='Parsing the pcap file...'):
            sec, usec, wirelen, caplen = metadata
            list_output.append((sec, usec, wirelen, caplen, payload))
        df_pcap = pd.DataFrame(list_output, columns=['sec', 'usec', 'wirelen', 'caplen', 'payload'])

        if path_csv:
            df_label = pd.read_csv(path_csv, header=None, names=['idx', 'label', 'y_desc'])
            assert df_pcap.shape[0] == df_label.shape[0], \
                f'Record count mismatch. {df_pcap.shape=}, {df_label.shape=}'
            assert (df_label['idx'].diff().bfill() == 1).all(), 'Field `idx` does not increase sequentially.'
            df_label['y'] = df_label['label'].map({'Normal': 0, 'Abnormal': 1})
        else:
            df_label = pd.DataFrame(index=df_pcap.index)
            df_label['y'] = 0
            df_label['y_desc'] = 'Normal'
        abstime = pd.to_datetime(df_pcap['sec'], unit='s') + pd.to_timedelta(df_pcap['usec'], unit=usec_unit)
        dupcounts = abstime.duplicated(keep=False).sum()

        if dupcounts > 0:
            print(f'There were {dupcounts} distinct timestamps.', end=' ')
            for _ in range(100):
                duplicated = abstime.duplicated()
                if duplicated.sum() == 0:
                    break
                abstime[duplicated] += pd.Timedelta(milliseconds=1)
            else:
                raise ValueError('Something went wrong.')
            print(f'-> {_} correction(s).')

        monotime = (abstime - abstime.min()).dt.total_seconds()
        df_pcap['payload'] = df_pcap['payload'].map(lambda x: np.frombuffer(x, dtype='uint8'))

        df: pd.DataFrame = pd.concat([
            abstime.rename('abstime'),
            monotime.rename('monotime'),
            df_pcap[['wirelen', 'caplen', 'payload']],
            df_label[['y', 'y_desc']]
        ], axis=1)

        df = df.sort_values('abstime')
        assert df['abstime'].is_monotonic_increasing
        assert df['monotime'].is_monotonic_increasing

        # Protocol specification
        df['ProtocolType'] = ''
        df.loc[df['wirelen'] == 60, 'ProtocolType'] = 'UDP'
        df.loc[df['wirelen'].isin([68, 90]), 'ProtocolType'] = 'PTP'
        df.loc[df['wirelen'].isin([82, 434]), 'ProtocolType'] = 'AVTP'
        # special treatment
        df.loc[df['y_desc'] == 'P_I', 'ProtocolType'] = 'PTP'

        return cls(df, **kwargs)

    @classmethod
    def towids_train(cls, **kwargs):
        return cls._load_towids_dataset(
            Path('dataset/Automotive_Ethernet_with_Attack_original_10_17_19_50_training.pcap'),
            'ns',
            Path('dataset/y_train.csv'),
            **kwargs
        )

    @classmethod
    def towids_test(cls, **kwargs):
        return cls._load_towids_dataset(
            Path('dataset/Automotive_Ethernet_with_Attack_original_10_17_20_04_test.pcap'),
            'ns',
            Path('dataset/y_test.csv'),
            **kwargs
        )

    def do_label(self, window_size) -> np.ndarray:
        y = self.df.rolling(window=window_size)['y'].max().dropna().astype('int32').values
        assert isinstance(y, np.ndarray)
        return y

    def trim(self, time_start=None, time_end=None, is_absolute=None):
        assert is_absolute is not None
        monotime_min = self.df['monotime'].min()
        monotime_max = self.df['monotime'].max()

        if time_start is not None:
            if is_absolute is False:
                time_start = monotime_min + time_start
            assert monotime_min < time_start
        else:
            time_start = monotime_min

        if time_end is not None:
            if is_absolute is False:
                time_end = monotime_max - time_end
            assert time_end < monotime_max
        else:
            time_end = monotime_max

        df = self.df.query(f'{time_start} <= monotime <= {time_end}').copy()
        # print('Before [{} ~ {}] / Required [{} ~ {}] / After [{} ~ {}]'.format(
        #     monotime_min, monotime_max,
        #     time_start, time_end,
        #     df['monotime'].min(), df['monotime'].max()
        # ))
        return Dataset(df)
        
    # Feature generator 1 (FG1)
    def do_fg1_transition_matrix(self, window_size=2048) -> np.array:
        # When the number of collected packets is n, a numpy array of shape = (n, 3, 3) should be the output
        df = self.df
        # proto_types = sorted(df['ProtocolType'].unique()) # ex) ['AVTP', 'PTP', 'UDP']
        idx = {'AVTP': 0, 'PTP': 1, 'UDP': 2} # ex) {'AVTP': 0, 'PTP': 1, 'UDP': 2}
        N = len(idx) # 3

        # 1. ProtocolType sequence -> integer index
        proto_seq = df['ProtocolType'].map(idx).values # [2, 0, 0, 1, 2]

        # 2. generate T
        def seq_to_transition_matrix(seq):
          T = np.zeros((N, N), dtype=np.float32)
          for i in range(len(seq) - 1):
            a, b = seq[i], seq[i+1]
            T[a, b] += 1
          T /= (len(seq)-1) # normalization
          return T

        if len(proto_seq) < window_size:
          raise ValueError(f"Insufficient data length ({len(proto_seq)}) for window_size {window_size}")

        # checkpoint
        print("Data shape:", proto_seq.shape)
        print("Window size:", window_size)

        # 3. sliding window using TimeseriesGenerator
        generator = TimeseriesGenerator(proto_seq, length=window_size, sampling_rate=1, stride=1, batch_size=1, shuffle=False)

        print("Generator length:", len(generator))
        # if len(generator) == 0:
        #   print("Warning: Generator is empty! Check window_size and data length.")
        #   return np.zeros((0, N, N))

        result = []
        for X, _ in generator:
          seq = X[0] # (window_size, )
          T = seq_to_transition_matrix(seq)
          result.append(T)

        return np.stack(result) # (num_windows, N, N)


    # Feature generator 2 (FG2)
    def do_fg2_payload(self, window_size=2048, byte_start=0x22, byte_end=0x22 + 9) -> np.array:
        '''
        - The paper's strategy is to take 9 bytes from the 0x22th byte for the payload loaded in each packet.  
        - Short payloads should be padded with 0x00.
        - When the number of collected packets is n, a numpy array with shape = (n, 9) should be generated. 
        - FG2 does not need to apply TimeseriesGenerator.
        '''
        assert byte_start < byte_end
        num_bytes = byte_end - byte_start # 9

        payloads = []
        for arr in self.df['payload'].values:
          segment = np.zeros(num_bytes, dtype=np.uint8) # [0, 0, 0, ..., 0]
          arr_len = len(arr)
          for i in range(num_bytes): # 9
            if byte_start + i < arr_len:
              segment[i] = arr[byte_start + i]
          payloads.append(segment / 255.0)

        return np.array(payloads) # (n ,9)


    # Feature generator 3 (FG3)
    def do_fg3_statistics(self, window_size=2048, methods=('mean', 'std', 'skew')) -> np.array:
        '''
        - When the number of collected packets is n, a numpy array of shape=(n, 3, 3) should be generated.
        - The <feature normalization strategy> described at the bottom right of page 5 of the paper must be implemented.
        '''
        df = self.df
        # proto_types = sorted(df['ProtocolType'].unique()) # ex) ['AVTP', 'PTP', 'UDP']
        idx = {'AVTP': 0, 'PTP': 1, 'UDP': 2} # ex) {'AVTP': 0, 'PTP': 1, 'UDP': 2}
        N = len(idx) # ex) 3

        monotime = df['monotime'].values
        protos = df['ProtocolType'].map(idx).values

        # each window is constructed as [window_size * 2]
        generator = TimeseriesGenerator(
            np.stack([monotime, protos], axis=1), # (n, 2)
            length = window_size,
            sampling_rate = 1,
            stride = 1,
            batch_size = 1,
            shuffle = False
            )

        # checkpoint
        print("Data shape:", np.stack([monotime, protos], axis=1).shape)
        print("Window size:", window_size)

        result = []
        for X, _ in generator:
          x_window = X[0] # (window_size, 2)
          t = x_window[:, 0] # first column of 'monotime' [1.0, 1.2, 1.3, 2.0, ...]
          p = x_window[:, 1].astype(int) # second column of 'protos(protocol index)' [0, 0, 1, 0]

          stat_matrix = np.full((N, 3), 1e+7, dtype=np.float32) # Initialize default value to 1e+7

          for i in range(N):
            t_i = t[p == i] # time sequence of the ith protocol / t : [1.0, 1.2, 1.3, 2.0, ...] / p==i : [True, True, False, True, ...] / t[p==i] : [1.0, 1.2, 2.0] => select protocol by this workflow
            if len(t_i) >= 2:
                diffs = np.diff(t_i)
                mean_val = np.mean(diffs)
                stat_matrix[i, 0] = mean_val

                if len(diffs) >= 2:
                    std_val = np.std(diffs)
                    stat_matrix[i, 1] = std_val
                if len(diffs) >= 3:
                    skew_val = np.abs(skew(diffs))
                    stat_matrix[i, 2] = skew_val
            
          stat_matrix = np.where(stat_matrix == 0, 1e-7, stat_matrix)
          stat_matrix = np.log10(stat_matrix)

          result.append(stat_matrix)

        return np.stack(result) # (num_windows, N, 3)

dataset_train = Dataset.towids_train()
dataset_test = Dataset.towids_test()

Parsing the pcap file...: 1203737it [00:00, 1723370.84it/s]
Parsing the pcap file...: 791611it [00:00, 1852064.52it/s]


There were 2 distinct timestamps. -> 1 correction(s).


### 1. Train file

In [3]:
display(dataset_train.df)
print(dataset_train.df['ProtocolType'].value_counts())
print(dataset_train.df['y_desc'].value_counts())

Unnamed: 0,abstime,monotime,wirelen,caplen,payload,y,y_desc,ProtocolType
0,2020-09-12 09:51:04.715221,0.000000,434,434,"[145, 239, 0, 0, 254, 0, 0, 252, 112, 0, 0, 3,...",0,Normal,AVTP
1,2020-09-12 09:51:04.715245,0.000024,434,434,"[145, 239, 0, 0, 254, 0, 0, 252, 112, 0, 0, 3,...",0,Normal,AVTP
2,2020-09-12 09:51:04.715326,0.000105,434,434,"[145, 239, 0, 0, 254, 0, 0, 252, 112, 0, 0, 3,...",0,Normal,AVTP
3,2020-09-12 09:51:04.715450,0.000229,434,434,"[145, 239, 0, 0, 254, 0, 0, 252, 112, 0, 0, 3,...",0,Normal,AVTP
4,2020-09-12 09:51:04.715559,0.000338,434,434,"[145, 239, 0, 0, 254, 0, 0, 252, 112, 0, 0, 3,...",0,Normal,AVTP
...,...,...,...,...,...,...,...,...
1203732,2020-09-12 10:00:16.911784,552.196563,60,60,"[220, 166, 50, 94, 72, 71, 220, 166, 50, 93, 2...",0,Normal,UDP
1203733,2020-09-12 10:00:16.912231,552.197010,60,60,"[220, 166, 50, 94, 72, 71, 220, 166, 50, 93, 2...",0,Normal,UDP
1203734,2020-09-12 10:00:16.912686,552.197465,60,60,"[220, 166, 50, 94, 72, 71, 220, 166, 50, 93, 2...",0,Normal,UDP
1203735,2020-09-12 10:00:16.913172,552.197951,60,60,"[220, 166, 50, 94, 72, 71, 220, 166, 50, 93, 2...",0,Normal,UDP


ProtocolType
UDP     846647
AVTP    287086
PTP      69601
Name: count, dtype: int64
y_desc
Normal    954509
C_D        85466
P_I        64635
F_I        35112
M_F        33765
C_R        29847
Name: count, dtype: int64


### 2. Test file 

In [4]:
display(dataset_test.df)
print(dataset_test.df['ProtocolType'].value_counts())
print(dataset_test.df['y_desc'].value_counts())

Unnamed: 0,abstime,monotime,wirelen,caplen,payload,y,y_desc,ProtocolType
0,2020-09-12 10:02:59.795192,0.000000,434,434,"[145, 239, 0, 0, 254, 0, 0, 252, 112, 0, 0, 3,...",0,Normal,AVTP
1,2020-09-12 10:02:59.810189,0.014997,434,434,"[145, 239, 0, 0, 254, 0, 0, 252, 112, 0, 0, 3,...",0,Normal,AVTP
2,2020-09-12 10:02:59.810205,0.015013,434,434,"[145, 239, 0, 0, 254, 0, 0, 252, 112, 0, 0, 3,...",0,Normal,AVTP
3,2020-09-12 10:02:59.810295,0.015103,434,434,"[145, 239, 0, 0, 254, 0, 0, 252, 112, 0, 0, 3,...",0,Normal,AVTP
4,2020-09-12 10:02:59.810414,0.015222,434,434,"[145, 239, 0, 0, 254, 0, 0, 252, 112, 0, 0, 3,...",0,Normal,AVTP
...,...,...,...,...,...,...,...,...
791606,2020-09-12 10:09:36.422031,396.626839,60,60,"[220, 166, 50, 94, 72, 71, 220, 166, 50, 93, 2...",0,Normal,UDP
791607,2020-09-12 10:09:36.422535,396.627343,60,60,"[220, 166, 50, 94, 72, 71, 220, 166, 50, 93, 2...",0,Normal,UDP
791608,2020-09-12 10:09:36.422997,396.627805,60,60,"[220, 166, 50, 94, 72, 71, 220, 166, 50, 93, 2...",0,Normal,UDP
791609,2020-09-12 10:09:36.423462,396.628270,60,60,"[220, 166, 50, 94, 72, 71, 220, 166, 50, 93, 2...",0,Normal,UDP


ProtocolType
UDP     563731
AVTP    198013
PTP      29580
Name: count, dtype: int64
y_desc
Normal    660490
C_D        41203
C_R        29847
P_I        26013
F_I        16962
M_F        16809
Name: count, dtype: int64


- Create train/validation/test sets by dividing the two packet dump datasets (dataset_train, dataset_test) into different time ranges.
- Organize the number of malicious traffic (intrusion) and normal traffic (benign) in these sets into a table as below.

In [5]:
import pandas as pd

# Arguments to be passed to the do function: 
# [dataset, purpose, start time, end time, whether to remove the last 5 seconds of noise]
args = [
    [dataset_train, 'Train', 5, 60, False],
    [dataset_train, 'Validation', 60, 71.11, False],
    [dataset_train, 'Test', 71.11, None, True],
    [dataset_test, 'Train', 5, 80, False],
    [dataset_test, 'Validation', 80, 91.88, False],
    [dataset_test, 'Test', 91.89, None, True],
]

def do(dataset, purpose, time_start, time_end, trim_last_5sec):
    name = 'Packet dump 1' if dataset is dataset_train else 'Packet dump 2'

    dataset = dataset.trim(time_start, time_end, is_absolute=True) # slice [time_start, time_end] part only in the entire dataset
    if trim_last_5sec: # Remove noise remaining after the data collection step
        dataset = dataset.trim(time_end=5, is_absolute=False)
        time_end = dataset.df['monotime'].max() # Since 'time_end' has changed after removing noise, update it again with the actual maximum time
    a = dataset.df['y'].value_counts() 
    a.name = name 
    a['Purpose'] = purpose 
    a['Time range'] = '[{:.2f}, {:.2f}]'.format(time_start, time_end)
    a = a.rename({0: 'Benign', 1: 'Intrusion'}) # 0 as benign, 1 as intrusion
    a = a.reindex(['Purpose', 'Time range', 'Benign', 'Intrusion'], fill_value=0)
    return a, dataset


list_output = list()
list_dataset_sub = list() ######### From here, you can retrieve the Dataset instance as needed.
for arg in args:
    output, dataset_sub = do(*arg)
    list_output.append(output)
    list_dataset_sub.append(dataset_sub)

df = pd.DataFrame(list_output)
df.index.name = 'Packet dump'
df[['Benign', 'Intrusion']] = df[['Benign', 'Intrusion']].map('{:,}'.format)
df

y,Purpose,Time range,Benign,Intrusion
Packet dump,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Packet dump 1,Train,"[5.00, 60.00]",97715,0
Packet dump 1,Validation,"[60.00, 71.11]",19606,0
Packet dump 1,Test,"[71.11, 547.20]",819586,248080
Packet dump 2,Train,"[5.00, 80.00]",130520,0
Packet dump 2,Validation,"[80.00, 91.88]",19943,0
Packet dump 2,Test,"[91.89, 391.63]",496151,129226


### TimeseriesGenerator

In [6]:
import numpy as np


class TimeseriesGenerator:
    def __init__(self, data, length, sampling_rate=1, stride=1,
                 start_index=0, end_index=None,
                 shuffle=False, reverse=False, batch_size=128, label=None):
        self.data = data
        self.length = length
        self.sampling_rate = sampling_rate
        self.stride = stride
        self.start_index = start_index + length
        if end_index is None:
            end_index = len(data)
        self.end_index = end_index
        self.shuffle = shuffle
        self.reverse = reverse
        self.batch_size = batch_size
        self.label = label if label is None else np.array(label)
        if self.start_index > self.end_index:
            raise ValueError(
                "`start_index+length=%i > end_index=%i` "
                "is disallowed, as no part of the sequence "
                "would be left to be used as current step."
                % (self.start_index, self.end_index)
            )

    def __len__(self):
        return (self.end_index - self.start_index + self.batch_size * self.stride) // (self.batch_size * self.stride)

    def __getitem__(self, index):
        rows = self.__index_to_row__(index)
        samples, y = self.__compile_batch__(rows)
        return samples, y

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]
    
    def __index_to_row__(self, index):  # Returns a list of rows that will compose a given batch (index). len(rows) is equal to the batch size.
        if self.shuffle:
            rows = np.random.randint(self.start_index, self.end_index + 1, size=self.batch_size)
        else:
            i = self.start_index + self.batch_size * self.stride * index
            rows = np.arange(i, min(i + self.batch_size * self.stride, self.end_index + 1), self.stride)
        return rows

    def __compile_batch__(self, rows):  # Generate time series features for each given row.
        samples = np.array([self.data[row - self.length: row: self.sampling_rate] for row in rows])
        if self.reverse:
            samples = samples[:, ::-1, ...]
        if self.length == 1:
            samples = np.squeeze(samples)

        if self.label is None:
            return samples, samples
        else:
            return samples, self.label[rows - self.length]

    @property
    def output_shape(self):
        x, y = self[0]
        return x.shape, y.shape

    @property
    def num_samples(self):
        count = 0
        for x, y in self:
            count += x.shape[0]
        return count

    def __str__(self):
        return '<TimeseriesGenerator data.shape={} / num_batches={:,} / output_shape={}>'.format(
            self.data.shape, len(self), self.output_shape,
        )

    def __repr__(self):
        return self.__str__()


### Define new Dataset Class 
- Converting the shape of each FG1-3, to match the dimension as an input value of the Autoencoder model afterwards
    - x = (T, P, S)

- dataset[i][0] → ((9,), (2048, 9), (9,))
- dataset[i][1] → ((9,), (2048, 9), (9,)) ⇒ x == y since it's an autoencoder model
- dataloader[i][0] → ((b, 9), (b, 2048, 9), (b, 9))

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms # 데이터 전처리(변환) 도구

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"{device} is avaiable")

cuda:0 is avaiable


In [9]:
class AEGenerator(Dataset):
    def __init__(self, T, P, S): # 데이터셋 초기화, 파일 불러오기, 전처리 정의
        '''
        T : nparray (n, 3, 3)
        P : nparray (n, 9)
        S : nparray (n, 3, 3)
        '''
        self.T = T
        self.P = P
        self.S = S
        self.n = T.shape[0]
        pass
    
    def __len__(self): # 데이터셋의 총 샘플 수 리턴
        return self.n

    def __getitem__(self, idx): # 하나의 샘플(x, y)를 리턴
        '''
        T : (3, 3) -> (9,)
        self.T[idx] : nparray(3, 3)
        from_numpy() : nparray -> torch tensor
        .flatten() : 2D (3, 3) -> 1D (9, )
        
        ~~Intuitive Example~~
        self.T[idx] = [[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]]
        Output : tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])
        '''
        t = torch.from_numpy(self.T[idx].astype('float32')).flatten()
        
        '''
        P : (9,) -> (2048, 9)
        self.P[max(0, idx-2047):idx+1] : 과거 2048개의 row (없으면 시작부터 count)

        e.g. if idx=5, self.P[0:6] -> shape(6, 9)

        padding 작업 : 과거 row 수 < 2048(window size)이면 0으로 채운 tensor를 만들어서 앞에 붙이고, .cat()으로 두 tensor를 concat

        e.g.
        p.shape = (6, 9)
        padding.shape = (2042, 9) # 2048 - 6 = 2042로 계산
        p.shape = (2048, 9) # torch.cat
        '''
        p = torch.from_numpy(self.P[max(0, idx-2047):idx+1].astype('float32'))
        if p.shape[0] < 2048:
            padding = torch.zeros((2048 - p.shape[0], 9), dtype=torch.float32)
            p = torch.cat([padding, p], dim = 0)


        '''
        T와 동일하게 S : (3, 3) -> (9,)
        self.S[idx] = [[0, 1, 0],
                    [1, 0, 1],
                    [0, 1, 0]]
        => tensor([0., 1., 0., 1., 0., 1., 0., 1., 0.])
        '''
        s = torch.from_numpy(self.S[idx].astype('float32')).flatten()

        x = y = (t, p, s)
        
        return x, y

In [10]:
original_dataset = list_dataset_sub[0] # 'train' part in dataset_train
T = original_dataset.do_fg1_transition_matrix()
print("FG1 original shape:", T.shape)
print()
P = original_dataset.do_fg2_payload()
print("FG2 original shape:", P.shape)
print()
S = original_dataset.do_fg3_statistics()
print("FG3 original shape:", S.shape)

Data shape: (97715,)
Window size: 2048
Generator length: 95668
FG1 original shape: (95668, 3, 3)

FG2 original shape: (97715, 9)

Data shape: (97715, 2)
Window size: 2048
FG3 original shape: (95668, 3, 3)


In [11]:
dataset = AEGenerator(T, P, S)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True) # n = 32

In [12]:
for batch_x, _ in dataloader:
    t_batch, p_batch, s_batch = batch_x
    print(f'T shape: {t_batch.shape}') # (32, 9)
    print(f'P shape: {p_batch.shape}') # (32, 2048, 9)
    print(f'S shape: {s_batch.shape}') # (32, 9)
    assert t_batch.shape[0] == p_batch.shape[0] == s_batch.shape[0], "T, P, S must have same number of samples!"
    break

T shape: torch.Size([32, 9])
P shape: torch.Size([32, 2048, 9])
S shape: torch.Size([32, 9])


- So as above, currently the shape of the elements in input tuple x is as below:
    - t : (b, 9)
    - p : (b, 2048, 9)
    - s : (b, 9)
- Encoder Architecture in the paper:
    - T : Flatten -> Dense(64) -> Dense(64)
    - P : Seperable Conv1D -> ... (n layers: depends on w) -> Seperable Conv1D -> Seperable Conv1D(576)
    - S : Flatten -> Dense(64) -> Dense(64)
    - latent vector h : 704 = 64 + 64 + 576
- Decoder Architecture in the paper:
    - latent space h(704) has 3 independent network : T', P', S'
    - T', S' : Dense(64) -> Dense(64) -> Dense(9)
    - P' : Transposed SeperableConv1D n layers (last output's channel=9)

- Pytorch doesn't directly afford SeperableConv1d...
    - let's implement in w/ depthwise + pointwise

In [13]:
class SeperableConv1d(nn.Module):
    def __init__(self, input_c, output_c, kernel_size, **kwargs):
        super().__init__()
        self.depthwise = nn.Conv1d(input_c, input_c, kernel_size, groups=input_c, **kwargs)
        self.pointwise = nn.Conv1d(input_c, output_c, 1)
    
    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

In [14]:
class SeparableConvTranspose1d(nn.Module):
    def __init__(self, input_c, output_c, kernel_size, **kwargs):
        super().__init__()
        self.depthwise = nn.ConvTranspose1d(input_c, input_c, kernel_size, groups=input_c, **kwargs)
        self.pointwise = nn.ConvTranspose1d(input_c, output_c, 1)
    
    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

In [18]:
class Autoencoder(nn.Module):
    def __init__(self, num_p_layers=8):
        super(Autoencoder, self).__init__()

        # 1. Encoder definition
        self.t_fc1 = nn.Linear(9, 64)
        self.t_fc2 = nn.Linear(64, 64)

        self.s_fc1 = nn.Linear(9, 64)
        self.s_fc2 = nn.Linear(64, 64)

        p_layers = []
        in_ch = 9
        for i in range(num_p_layers):
            out_ch = 64 if i < num_p_layers - 1 else 576
            p_layers.append(SeperableConv1d(in_ch, out_ch, kernel_size=3, padding=1))
            in_ch = out_ch
        self.p_layers = nn.ModuleList(p_layers)

        
        # 2. Decoder definition
        self.t_dec_fc1 = nn.Linear(704, 64)
        self.t_dec_fc2 = nn.Linear(64, 64)
        self.t_dec_fc3 = nn.Linear(64, 9)

        self.s_dec_fc1 = nn.Linear(704, 64)
        self.s_dec_fc2 = nn.Linear(64, 64)
        self.s_dec_fc3 = nn.Linear(64, 9)

        p_dec_layers = []
        in_ch = 704
        for i in range(num_p_layers):
            out_ch = 64 if i < num_p_layers - 1 else 9
            p_dec_layers.append(SeparableConvTranspose1d(in_ch, out_ch, kernel_size=3, padding=1))
            in_ch = out_ch
        self.p_dec_layers = nn.ModuleList(p_dec_layers)
    
    def forward(self, x):
        t, p, s = x

        # 1. Encoder
        t = F.relu(self.t_fc1(t))
        t = F.relu(self.t_fc2(t))

        s = F.relu(self.s_fc1(s))
        s = F.relu(self.s_fc2(s))

        p = p.permute(0, 2, 1) # (b, 9, 2048)
        for layer in self.p_layers[:-1]:
            p = F.relu(layer(p))
        p = self.p_layers[-1](p)
        p = p.mean(dim=2) # (b, 576)


        # latent vector
        h = torch.cat([t, p, s], dim=1)


        # 2. Decoder
        t = F.relu(self.t_dec_fc1(h))
        t = F.relu(self.t_dec_fc2(t))
        t = self.t_dec_fc3(t)

        s = F.relu(self.s_dec_fc1(h))
        s = F.relu(self.s_dec_fc2(s))
        s = self.s_dec_fc3(s)

        p = h. unsqueeze(2).repeat(1, 1, 2048) # (b, 704, 2048)
        for layer in self.p_dec_layers[:-1]:
            p = F.relu(layer(p))
        p = self.p_dec_layers[-1](p) # (b, 9, 2048)
        p = p.permute(0, 2, 1) # (b, 2048, 9)

        return t, p, s

In [16]:
model = Autoencoder().to(device)
# loss function
criterion = nn.MSELoss() 
# optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [19]:
# Just a simple test~
model = Autoencoder(num_p_layers=8)
t = torch.randn(8, 9)
p = torch.randn(8, 2048, 9)
s = torch.randn(8, 9)

h = model((t, p, s))

In [20]:
epoch1 = 2 # 논문에서는 20이지만 일단 작은 수로 잘 돌아가는지 테스트
for epoch in range(epoch1):
    loss_AE = 0.0              # loss값 누적할 변수 초기화
    for x, _ in dataloader:
        t, p, s = (item.to(device) for item in x) # 배치의 입력 데이터를 gpu로 보냄
        t, p, s = x

        optimizer.zero_grad()       # 이전 배치에서 계산된 기울기 초기화
        t_hat, p_hat, s_hat = model(x)
        loss = ((t - t_hat)**2 + (p - p_hat)**2 + (s - s_hat)**2).mean() # MSE

        '''
        # T loss
        loss_t = F.mse_loss(t_hat, t, reduction='mean')  # PyTorch의 mean은 전체 평균
        # P loss
        loss_p = F.mse_loss(p_hat, p, reduction='mean')
        # S loss
        loss_s = F.mse_loss(s_hat, s, reduction='mean')

        # loss 합
        loss = loss_t + loss_p + loss_s        
        '''

        loss.backward()             # 역전파를 통해 각 파라미터에 대한 기울기를 계산
        optimizer.step()            # weight update
        loss_AE += loss.item() # batch의 loss값을 누적해 epoch 전체 loss를 계산
    cost = loss_AE / len(dataloader) # average loss 값 계산
    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f"[{epoch + 1}] loss: {cost:.3f}") # 현재 epoch와 평균 손실 값을 출력


RuntimeError: The size of tensor a (32) must match the size of tensor b (2048) at non-singleton dimension 1