In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from threading import Thread
import os
import glob
import sys
sys.path.append('..')

from util.data import *
from config import cfg
from data import BaseDataset
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.encoding = self.generate_encoding(d_model, max_len)

    def generate_encoding(self, d_model, max_len):
        encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)
        encoding = encoding.unsqueeze(0)
        return encoding

    def forward(self, x):
        seq_length = x.size(1)
        return self.encoding[:, :seq_length].to(x.device)

class TransformerEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_heads):
        super(TransformerEncoder, self).__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)
        self.pos_encoder = PositionalEncoding(hidden_dim)
        self.encoder_layers = nn.TransformerEncoderLayer(hidden_dim, num_heads)
        self.encoder = nn.TransformerEncoder(self.encoder_layers, num_layers)

    def forward(self, x):
        x = self.embedding(x)
        x = x + self.pos_encoder(x)
        x = x.permute(1, 0, 2)  # (seq_len, batch_size, hidden_dim)
        output = self.encoder(x)
        return output

class TransformerDecoder(nn.Module):
    def __init__(self, output_dim, hidden_dim, num_layers, num_heads):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Linear(output_dim, hidden_dim)
        self.pos_decoder = PositionalEncoding(hidden_dim)
        self.decoder_layers = nn.TransformerDecoderLayer(hidden_dim, num_heads)
        self.decoder = nn.TransformerDecoder(self.decoder_layers, num_layers)

    def generate_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, x, encoder_output):
        x = self.embedding(x)
        x = x + self.pos_decoder(x)
        x = x.permute(1, 0, 2)  # (seq_len, batch_size, hidden_dim)

        # Generate causal mask
        tgt_mask = self.generate_mask(x.size(0)).to(x.device)

        output = self.decoder(x, encoder_output, tgt_mask=tgt_mask)
        return output

class Transformer(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim, num_layers, num_heads):
        super(Transformer, self).__init__()
        self.encoder = TransformerEncoder(input_dim, hidden_dim, num_layers, num_heads)
        self.decoder = TransformerDecoder(output_dim, hidden_dim, num_layers, num_heads)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.pos_encoder = PositionalEncoding(hidden_dim)

    def forward(self, x, y):
        encoder_output = self.encoder(x)
        decoder_output = self.decoder(y, encoder_output)
        output = self.fc(decoder_output)
        # reshape back to batch_size x seq_len x num_channels
        output = output.permute(1, 0, 2)
        return output
    

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3,3), stride=(1,1), padding=(1,1)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    def forward(self, x):
        return x + self.block(x)
    
# Encoder layer
class EncoderLayer(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=(2,2)):
        super(EncoderLayer, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3,2), stride=(1,1), padding=(2,1)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3,3), stride=scale_factor),
        )
    def forward(self, x):
        x = self.encoder(x.float())
        return x
    
class Conv3DEncoderLayer(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=(2,2)):
        super(Conv3DEncoderLayer, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=(3,3,2), stride=(1,1,1), padding=(1,1,1)),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(3,3,3), stride=scale_factor),
        )
    
# Decoder layer
class DecoderLayer(nn.Module):
    def __init__(self, in_channels, out_channels, last=False, scale_factor=(2,2), output_shape=(1000,20)):
        super(DecoderLayer, self).__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3,2), stride=(1,1)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        if last:
            self.decoder.append(nn.Upsample(size=output_shape, mode='bilinear', align_corners=False))
        else:
            self.decoder.append(nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False))

    def forward(self, x):
        x = self.decoder(x)
        return x
    
# Construct a model with 3 conv layers 3 residual blocks and 3 deconv layers using the ResNet architecture
class NeuroPose(nn.Module):
    def __init__(self, in_channels=1, num_residual_blocks=3, output_shape=(1000,20)):
        super(NeuroPose, self).__init__()
        
        encoder_channels = [in_channels, 32, 128, 256]
        scale_factors = [(5,2), (4,2), (2,2)]

        self.encoder = self.make_encoder_layers(channels=encoder_channels, scale_factors=scale_factors)

        # get last number of filters from encoder
        resnet_channels = encoder_channels[-1]
        self.resnet = self.make_resnet_layers(channel=resnet_channels, num_layers=num_residual_blocks)
        
        self.decoder = self.make_decoder_layers(channels=encoder_channels[::-1], scale_factors=scale_factors[::-1], output_shape=output_shape)

    def make_encoder_layers(self, channels = [1, 32, 128, 256], scale_factors = [(5,2), (4,2), (2,2)]):
        # sequence of encoder layers
        layers = []
        for i in range(len(channels)-1):
            layers.append(EncoderLayer(channels[i], channels[i+1], scale_factor=scale_factors[i]))

        return nn.Sequential(*layers)

    def make_decoder_layers(self, channels = [256, 128, 32, 16], scale_factors = [(2,2), (4,2), (5,2)], output_shape=(1000,20)):
        # sequence of decoder layers
        layers = []
        for i in range(len(channels)-2):
            layers.append(DecoderLayer(channels[i], channels[i+1], scale_factor=scale_factors[i]))
        layers.append(DecoderLayer(channels[-2], channels[-1], last=True, output_shape=output_shape))

        return nn.Sequential(*layers)

    def make_resnet_layers(self, channel=256, num_layers=3):
        # sequence of resnet layers
        layers = []
        for i in range(num_layers):
            layers.append(ResidualBlock(channel, channel))

        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.resnet(x)
        x = self.decoder(x)
        x =x[:,-1,]
        return x
    
    #load from pretrained weights
    def load_pretrained(self, pretrained_path):
        pretrained_dict = torch.load(pretrained_path, map_location=torch.device('cpu'))
        model_dict = self.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        
        model_dict.update(pretrained_dict) 
        self.load_state_dict(model_dict)

        del pretrained_dict

In [13]:
# Hyperparameters
input_dim = 16    # Replace with the actual size of your input vocabulary
output_dim = 20     # Assuming 3 for x, y, z coordinates in pose estimation
hidden_dim = 256
num_layers = 4
num_heads = 8
lr = 0.001
batch_size = 32
epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
kwargs = {
        'data_path': '../dataset/FPE/S1/p3',
        'seq_len':200,
        'num_channels':16,
        'stride':1,
        'filter_data':True,
        'fs':cfg.DATA.EMG.SAMPLING_RATE,
        'Q':cfg.DATA.EMG.Q,
        'low_freq':cfg.DATA.EMG.LOW_FREQ,
        'high_freq':cfg.DATA.EMG.HIGH_FREQ,
        'notch_freq':cfg.DATA.EMG.NOTCH_FREQ,
        'ica': False,
        'transform': None,
        'target_transform': None,
    }

In [14]:

DATA_SOURCES = {
    'emg': read_emg,
    'leap': read_leap,
}

class EMGLeap(BaseDataset):
    def __init__(self, kwargs):
        super().__init__(**kwargs)

        # read the data
        edf_files, csv_files = self.read_dirs()

        if len(edf_files) == 0:
            raise ValueError(f'No edf files found in {self.data_path}')
        if len(csv_files) == 0:
            raise ValueError(f'No csv files found in {self.data_path}')

        threads = [None] * len(edf_files)
        results = {
            'data': [None] * len(edf_files),
            'label': [None] * len(edf_files),
            'gestures': [None] * len(edf_files)
        }

        #  read the data
        self.data, self.label, self.gestures = [], [], []
        for i in range(len(edf_files)):
            print(f'Reading data from {edf_files[i]} and {csv_files[i]}')
            thread = Thread(target=self.prepare_data, args=(edf_files[i], csv_files[i], results, i))
            threads[i] = thread

        for i in range(len(edf_files)):
            threads[i].start()

        for i in range(len(edf_files)):
            threads[i].join()

        self.data = np.concatenate(results['data'], axis=0)
        self.label = np.concatenate(results['label'], axis=0)


        # to tensor
        self.data = torch.tensor(self.data, dtype=torch.float32)
        self.label = torch.tensor(self.label, dtype=torch.float32)

        if self.transform:
            self.data = self.transform(self.data)

        if self.target_transform:
            self.label = self.target_transform(self.label)

    def read_dirs(self):
        if isinstance(self.data_path, str):
            self.data_path = [self.data_path]
        all_files = []
        for path in self.data_path:
            if not os.path.isdir(path):
                raise ValueError(f'{path} is not a directory')
            else:
                print(f'Reading data from {path}')
                all_files += [f for f in glob.glob(os.path.join(path, '**/*'), recursive=True) if
                              os.path.splitext(f)[1] in ['.edf', '.csv']]

        # # Traverse through all the directories and read the data
        # all_files = [f for f in glob.glob(os.path.join(self.data_path, '**/*'), recursive=True) if os.path.splitext(f)[1] in ['.edf', '.csv']]
        # Separate .edf and .csv files

        edf_files = sorted([file for file in all_files if file.endswith('.edf')])
        csv_files = sorted([file for file in all_files if file.endswith('.csv')])

        return edf_files, csv_files

    def print_dataset_specs(self):
        print("data shape: ", self.data.shape)

    def prepare_data(self, data_path, label_path, results={}, index=0):
        data, annotations, header = DATA_SOURCES['emg'](data_path)
        label, _, _ = DATA_SOURCES['leap'](label_path, rotations=True, positions=False)

        if index == 0:
            # save the column names for the label
            self.label_columns = list(label.columns)
            self.data_columns = list(data.columns)

        # set the start and end of experiment
        start_time = max(min(data.index), min(label.index))
        end_time = min(max(data.index), max(label.index))

        # select only the data between start and end time
        data = data.loc[start_time:end_time]
        label = label.loc[start_time:end_time]

        self.label_columns = list(label.columns)
        # Merge the two DataFrames based on the 'time' column
        merged_df = pd.merge_asof(data, label, on='time', direction='forward')
        data = merged_df[data.columns].to_numpy()
        label = merged_df[label.columns].to_numpy()

        data = torch.tensor(data, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.float32)

        #  convert into shape NxSxC with a sliding window using torch roll
        data = data.unfold(0, self.seq_len, self.stride).permute(0, 2, 1)
        label = label.unfold(0, self.seq_len, self.stride).permute(0, 2, 1)

        # data, label, gestures = create_windowed_dataset(merged_df, annotations, self.seq_len, self.stride)
        #  convert into shape NxSxC with a sliding window


        # normalize the data
        data = self.normalize_and_filter(data)

        results['data'][index] = data
        results['label'][index] = label
    
    
    def normalize_and_filter(self, data=None):

        N, C, L = data.shape
        data_sliced = data.reshape(-1, L)

        # normalize the data
        scaler = StandardScaler()
        data_sliced = scaler.fit_transform(data_sliced)

        print("Filtering data...")
        # filter the data
        if self.filter_data:
            data_sliced = self._filter_data(data_sliced)

        return data_sliced.reshape(N, C, L)
    
    def __getitem__(self, index):
        return self.data[index], self.label[index]
    
    def __len__(self):
        return len(self.data)


In [15]:
dataset = EMGLeap(kwargs=kwargs)
train_idx, test_idx = train_test_split(list(range(len(dataset))), test_size=0.2, random_state=42)

train_dataset = torch.utils.data.Subset(dataset, train_idx)
test_dataset = torch.utils.data.Subset(dataset, test_idx)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Reading data from ../dataset/FPE/S1/p3
Reading data from ../dataset/FPE/S1/p3/fpe_pos3_001_S1_rep0_BT.edf and ../dataset/FPE/S1/p3/fpe_pos3_001_S1_rep0_BT.csv
2024-01-02 11:17:25
Filtering data...


In [16]:
# Instantiate the model
model = Transformer(input_dim, output_dim, hidden_dim, num_layers, num_heads)
model = model.to(device)
# Loss and optimizer (using Mean Absolute Error for regression)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, verbose=True)



In [17]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
count_parameters(model)

14483988

In [19]:
# Training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0.0

    for batch_idx, (input_seq, target_seq) in enumerate(train_loader):
        optimizer.zero_grad()
        input_seq, target_seq = input_seq.to(device), target_seq.to(device)
        # Forward pass
        output = model(input_seq, target_seq[:, :-1, :])  # Exclude the last pose from the target

        # Compute the loss
        loss = criterion(output, target_seq[:, 1:, :])  # Exclude the first pose from the target
        total_loss += loss.item()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        if (batch_idx + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")
            scheduler.step(loss)

    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{epochs}, Average Training Loss: {average_loss:.4f}")

    # Validation loop
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for input_seq, target_seq in val_loader:
            # Forward pass
            input_seq, target_seq = input_seq.to(device), target_seq.to(device)
            output = model(input_seq, target_seq[:, :-1, :])

            # Compute the loss
            loss = criterion(output, target_seq[:, 1:, :])
            val_loss += loss.item()

    average_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch + 1}/{epochs}, Validation Loss: {average_val_loss:.4f}")

Epoch 1/10, Batch 10/635, Loss: 24.1449
Epoch 1/10, Batch 20/635, Loss: 23.3261
Epoch 1/10, Batch 30/635, Loss: 23.4143
Epoch 1/10, Batch 40/635, Loss: 24.2512
Epoch 1/10, Batch 50/635, Loss: 25.6130
Epoch 1/10, Batch 60/635, Loss: 25.3755
Epoch 1/10, Batch 70/635, Loss: 23.4997
Epoch 1/10, Batch 80/635, Loss: 22.0900
Epoch 1/10, Batch 90/635, Loss: 24.3088
Epoch 1/10, Batch 100/635, Loss: 23.3273
Epoch 1/10, Batch 110/635, Loss: 23.1804
Epoch 1/10, Batch 120/635, Loss: 24.2089
Epoch 1/10, Batch 130/635, Loss: 22.4082
Epoch 1/10, Batch 140/635, Loss: 25.0379
Epoch 1/10, Batch 150/635, Loss: 25.2241
Epoch 1/10, Batch 160/635, Loss: 23.7424
Epoch 1/10, Batch 170/635, Loss: 20.6424
Epoch 1/10, Batch 180/635, Loss: 22.4621
Epoch 1/10, Batch 190/635, Loss: 24.6297
Epoch 1/10, Batch 200/635, Loss: 20.2762
Epoch 1/10, Batch 210/635, Loss: 22.8169
Epoch 1/10, Batch 220/635, Loss: 22.4016
Epoch 1/10, Batch 230/635, Loss: 23.2747
Epoch 1/10, Batch 240/635, Loss: 23.3301
Epoch 1/10, Batch 250/635