In [1]:
import torch
import pandas as pd
import numpy as np

rng = np.random.default_rng()

In [42]:
# Following this tutorial: https://machinelearningmastery.com/joining-the-transformer-encoder-and-decoder-and-masking/
# Modified for PyTorch

#class TransformerAugmentations():
def padding_mask(input, pad_idx=0):
    # Create mask which marks the zero padding values in the input by a 1
    mask = input == pad_idx
    #mask = mask.float()

    return mask

def lookahead_mask(shape):
    # Mask out future entries by marking them with a 1.0
    mask = 1 - torch.tril(torch.ones((shape, shape)))
    mask = mask.masked_fill(mask == 1, float('-inf'))
 
    return mask


In [43]:
# Padding mask for encoder
enc_padding_mask = padding_mask(encoder_input)

# Padding and look-ahead masks for decoder
dec_in_padding_mask = padding_mask(decoder_input)
dec_in_lookahead_mask = lookahead_mask(decoder_input.shape[1])
dec_in_lookahead_mask = torch.maximum(dec_in_padding_mask, dec_in_lookahead_mask)

NameError: name 'encoder_input' is not defined

In [None]:
shape = 10
test_input = torch.ones((shape))
test_input[5:] = 0
p_mask = padding_mask(test_input)
print(p_mask)
print(p_mask.shape)
l_mask = lookahead_mask(shape)
print(l_mask)
print(test_input * p_mask)


tensor([0., 0., 0., 0., 0., 1., 1., 1., 1., 1.])
torch.Size([10])
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import os

def Transformer_train_fn(
        train_loader,
        val_loader,
        encoder_model,
        decoder_model,
        encoder_optimizer,
        decoder_optimizer,
        loss_fn,
        metric_loss_fn,
        num_epoch,
        device,
        save_path,
        writer,
        teacher_force_ratio=1,
        val_interval=100000,
        checkpoint=None,
):
    for epoch in range(num_epoch):
        print(f'===== Epoch: {epoch} =====')
        epoch_train_loss = 0
        epoch_train_metric = 0

        for train_step, train_data in enumerate(train_loader):
            train_source = train_data['encoder_inputs'].to(device)
            train_target = train_data['decoder_inputs'].to(device)

            # Zero optimizers
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            # Forward pass
            decoder_output = torch.zeros(4, 512, 7).to(device)
            train_target_unpacked, _ = torch.nn.utils.rnn.pad_packed_sequence(train_target, batch_first=True)
            train_target_unpacked.to(device)
            start = train_target_unpacked[:, 0, :].unsqueeze(1).to(device)
            teacher_force = True if random.random() < teacher_force_ratio else False

            encoder_hidden, encoder_cell = encoder_model(train_source)
            # print(encoder_hidden.shape)
            encoder_cell = torch.zeros(1, 4, 32).to(device)
            
            if train_step == 0:
                decoder_output, decoder_hidden, decoder_cell = decoder_model(train_target, encoder_hidden, encoder_cell)
                # print(f'Decoder Output: {decoder_output.shape}\t Decoder Hidden: {decoder_hidden.shape}\t Decoder Cell: {decoder_cell.shape}')
            elif train_step !=0 and teacher_force == True:
                decoder_output, decoder_hidden, decoder_cell = decoder_model(train_target, encoder_hidden, encoder_cell)
            elif train_step != 0 and teacher_force == False:
                for i in range(1, 512):
                    start = torch.nn.utils.rnn.pack_sequence(start)
                    decoder_output[:, i, :], decoder_hidden, decoder_cell = decoder_model(start, encoder_hidden, encoder_cell)
                    start = train_target_unpacked[:, i, :].unsqueeze(1)
                    encoder_hidden = decoder_hidden
                    encoder_cell = decoder_cell

            train_loss = loss_fn(decoder_output, train_target_unpacked)

            # Backwards
            train_loss.backward()

            # Update optimizers
            encoder_optimizer.step()
            decoder_optimizer.step()

            # Train loss
            epoch_train_loss += train_loss.item()

            # Train metric loss
            train_metric = metric_loss_fn(decoder_output, train_target_unpacked)
            epoch_train_metric += train_metric

        # Average losses for tensorboard
        epoch_train_loss /= (train_step+1)
        writer.add_scalar('Training MSE per Epoch', epoch_train_loss, epoch)
        epoch_train_metric /= (train_step+1)
        writer.add_scalar('Training MAE per Epoch', epoch_train_metric, epoch)
        

        if epoch+1 % val_interval == 0:
            encoder_model.eval()
            decoder_model.eval()
            with torch.no_grad():
                epoch_val_loss = 0
                epoch_val_metric = 0

                for val_step, val_data in enumerate(val_loader):
                    val_source = val_data['encoder_inputs'].to(device)
                    val_target = val_data['decoder_inputs'].to(device)

                    # Run validation model
                    val_encoder_hidden, val_encoder_cell = encoder_model(val_source)
                    val_decoder_output, val_decoder_hidden, val_decoder_cell = decoder_model(val_target, val_encoder_hidden, val_encoder_cell)

                    val_loss = loss_fn(val_decoder_output, val_target)

                    # Val loss
                    epoch_val_loss += val_loss.item()

                    # Val metric loss
                    val_metric = metric_loss_fn(val_decoder_output, val_target)
                    epoch_val_metric += val_metric

                # Average validation losses for tensorboard
                epoch_val_loss /= (val_step+1)
                writer.add_scalar('Validation MSE per Epoch', epoch_val_loss, epoch)
                epoch_val_metric /= (val_step+1)
                writer.add_scalar('Validation MAE per Epoch', epoch_val_metric, epoch)


                 # Save checkpoint
                if not os.path.exists(os.path.join(save_path, 'checkpoint')):
                    os.makedirs(os.path.join(save_path, 'checkpoint'))
                torch.save({'epoch': epoch,
                            'encoder_model_state_dict': encoder_model.state_dict(),
                            'decoder_model_state_dict': decoder_model.state_dict(),
                            'encoder_optim_state_dict': encoder_optimizer.state_dict(),
                            'decoder_optim_state_dict': decoder_optimizer.state_dict(),
                            'train_loss': epoch_train_loss,
                            'val_loss': epoch_val_loss},
                           os.path.join(save_path, 'checkpoint', 'checkpoint_{}.pth'.format(epoch))
                           )
                
                # Save best model
                if not os.path.exists(os.path.join(save_path, 'best')):
                    os.makedirs(os.path.join(save_path, 'best'))
                if epoch_val_metric > best_metric:
                    best_metric = epoch_val_metric
                    best_metric_epoch = epoch
                    torch.save(encoder_model.state_dict(), os.path.join(save_path, 'best', 'best_encoder_model.pth'))
                    torch.save(decoder_model.state_dict(), os.path.join(save_path, 'best', 'best_decoder_model.pth'))

    writer.close()

# Test dataset class

In [3]:
import glob
valid_files = glob.glob("/root/data/smartwatch/subjects/*/*_full.csv")
len(valid_files)

655

In [31]:

class SmartwatchDataset(torch.utils.data.Dataset):
    def __init__(self, valid_files, sample_period=0.02):
        """
        Parameters:
        -----------
        valid_files: list of filepaths to normalized data
        """
        super().__init__()
        self.data = []
        for file in valid_files:
            df = pd.read_csv(file)
            # Resample the data if needed
            df.index = pd.to_timedelta(df["time"], unit="seconds")
            df = df.drop("time", axis=1)
            df = df.resample(f"{sample_period}S").mean()
            self.data.append(df.values)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Returns tuple of (imu, mocap) at index"""
        item = self.data[index]
        imu = item[:, 0:9]  # IMU sensor data [accel, mag, gyro]
        mocap = item[:, 9:]  # Mocap data [pos, quat]
        return imu, mocap
    

class SmartwatchAugmentTransformer:
    """
    Collate function to apply random augmentations to the data
        - Randomly perturb the mocap positions
        - Randomly flip sign of mocap quaternion
        - Add random noise to IMU channels
        - Random crop to the signal (if possible)
    """
    def __init__(self, position_noise=0.2, accel_eps=0.01, gyro_eps=0.01, mag_eps=0.01, max_samples=512):
        """
        Parameters:
        -----------
        position_noise: float, limits on uniform distribution [-p, p] to add position offset to mocap
        accel_eps: float, standard deviation on Gaussian noise added to accelerometer channels
        gyro_eps: float, standard deviation on Gaussian noise added to gyroscope channels
        mag_eps: float, standard deviation on Gaussian noise added to mangetometer channels
        """
        self.position_noise = position_noise
        self.accel_eps = accel_eps
        self.gyro_eps = gyro_eps
        self.mag_eps = mag_eps
        self.max_samples = max_samples

    def _random_crop(self, imu, mocap):
        """
        Apply a random crop of the signal of length self.max_samples to both inputs and labels, if able to
        Due to targets being a shifted version of decoder inputs, we need to account for one extra timepoint
        """
        n, d = imu.shape
        max_offset = n - self.max_samples - 1

        if max_offset > 0:
            offset = rng.choice(max_offset)
            inds = slice(offset, offset + self.max_samples + 1)
            return imu[inds, :], mocap[inds, :]
        else:
            return imu, mocap
        

    def padding_mask(self, input, pad_idx=0, dim=512): 
        # Create mask which marks the zero padding values in the input by a 0
        mask = torch.zeros((dim))
        if input.shape[0] < dim:
            mask[input.shape[0]:] = 1
            return mask.bool()
        #mask = mask.float()

        return mask.bool()


    def lookahead_mask(self, shape):
        # Mask out future entries by marking them with a 1.0
        mask = 1 - torch.tril(torch.ones((shape, shape)))
        mask = mask.masked_fill(mask == 1, float('-inf'))
    
        return mask


    def __call__(self, data):
        """
        Parameters:
        -----------
        data: list of tuple of (imu, mocap) of length batch_size
            imu: np.ndarray, dimensions (n_samples, 9), signal data for IMU accel, gyro, and mag
            mocap: np.ndarray, dimensions (n_samples, 7), position and quaternion data from mocap

        Returns:
        --------
        collated_data: dict of torch.nn.utils.rnn.PackedSequence with keys ["encoder_inputs", "decoder_inputs", "targets"]
        """
        encoder_inputs = []
        decoder_inputs = []
        targets = []
        for (imu, mocap) in data:
            imu, mocap = self._random_crop(imu, mocap)

            n_in, d_in = imu.shape
            n_out, d_out = mocap.shape
            assert n_in == n_out, "IMU and mocap must have the same number of sequence elements"
            assert d_in == 9, f"IMU data has dimensionality {d_in} instead of 9"
            assert d_out == 7, f"Mocap data has dimensionality {d_out} instead of 7"

            # Augment XYZ positions
            offset = rng.uniform(-self.position_noise, self.position_noise, size=(1, 3))
            mocap[:, 0:3] += offset
            # Augment quaternion sign
            sign = rng.choice([-1, 1])
            mocap[:, 4:] *= sign

            accel_noise = rng.normal(loc=0, scale=self.accel_eps, size=(n_in, 3))
            gyro_noise = rng.normal(loc=0, scale=self.gyro_eps, size=(n_in, 3))
            mag_noise = rng.normal(loc=0, scale=self.mag_eps, size=(n_in, 3))

            noise = np.hstack([accel_noise, gyro_noise, mag_noise])
            imu += noise

            # Ensure targets are one timestep shifted wrt inputs
            encoder_inputs.append(torch.FloatTensor(imu[:-1, :]))
            decoder_inputs.append(torch.FloatTensor(mocap[:-1, :]))
            targets.append(torch.FloatTensor(mocap[1:, :]))

        lengths = [len(item) for item in encoder_inputs]
        inds = np.flip(np.argsort(lengths)).copy()  # PackedSequence expects lengths from longest to shortest
        lengths = torch.LongTensor(lengths)[inds]

        # Sort by lengths
        encoder_inputs = [encoder_inputs[i] for i in inds]
        decoder_inputs = [decoder_inputs[i] for i in inds]
        targets = [targets[i] for i in inds]

        # Padding mask for encoder
        enc_padding_mask = [self.padding_mask(input=encoder_inputs[i]) for i in inds]
        enc_lookahead_mask = [self.lookahead_mask(shape=encoder_inputs[i].shape[1]) for i in inds]
        
        # Padding and look-ahead masks for decoder
        dec_in_padding_mask = [self.padding_mask(input=decoder_inputs[i]) for i in inds]
        dec_in_lookahead_mask = [self.lookahead_mask(shape=decoder_inputs[i].shape[1]) for i in inds]
        #dec_in_lookahead_mask = [torch.maximum(dec_in_padding_mask[i], dec_in_lookahead_mask[i]) for i in inds]

        # Pad input, if needed
        for i, length in enumerate(lengths):
            if length != 512:
                print("Dim does not equal 512 - padding sequence") 
                encoder_inputs[i] = nn.functional.pad(encoder_inputs[i], pad=(0, 512 - encoder_inputs[i].shape[0]), mode='constant', value=0)
                decoder_inputs[i] = nn.functional.pad(decoder_inputs[i], pad=(0, 512 - decoder_inputs[i].shape[0]), mode='constant', value=0)
                targets[i] = nn.functional.pad(targets[i], pad=(0, 512 - targets[i].shape[0]), mode='constant', value=0)

        encoder_inputs = torch.stack(encoder_inputs)
        decoder_inputs = torch.stack(decoder_inputs)
        targets = torch.stack(targets)

        enc_padding_mask = torch.stack(enc_padding_mask)
        enc_lookahead_mask = torch.stack(enc_lookahead_mask)
        dec_in_padding_mask = torch.stack(dec_in_padding_mask)
        dec_in_lookahead_mask = torch.stack(dec_in_lookahead_mask)

        collated_data = {
            "encoder_inputs": encoder_inputs,
            "decoder_inputs": decoder_inputs,
            "targets": targets,
            "encoder_padding_mask": enc_padding_mask,
            "decoder_padding_mask": dec_in_padding_mask,
            "decoder_lookahead_mask": dec_in_lookahead_mask,
            "encoder_lookahead_mask": enc_lookahead_mask
        }
        return collated_data

def get_file_lists():
    """Get list of files to pass to dataset class
    Returns:
    --------
    train_files: list of str filepaths to pre-processed train data
    test_files: list of str filepaths to pre-processed test data
    """
    import glob
    valid_files = glob.glob("/root/data/smartwatch/subjects/*/*_full.csv")
    test_subjects = [f"S{n}" for n in [5, 10, 15, 20, 25, 30]]
    test_files = [file for file in valid_files for subject in test_subjects if f"/{subject}/" in file]
    train_files = [file for file in valid_files if file not in set(test_files)]
    return train_files, test_files


In [6]:
dataset = SmartwatchDataset(valid_files)

In [32]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, collate_fn=SmartwatchAugmentTransformer())

In [33]:
batch = next(iter(dataloader))

In [34]:
batch['encoder_padding_mask']

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [25]:
print(batch['encoder_inputs'].data.shape)
print(batch['encoder_padding_mask'].data.shape)
print(batch['encoder_lookahead_mask'].data.shape)
print(batch['decoder_inputs'].data.shape)
print(batch['decoder_padding_mask'].data.shape)
print(batch['decoder_lookahead_mask'].data.shape)

torch.Size([16, 512, 9])
torch.Size([16, 512])
torch.Size([16, 9, 9])
torch.Size([16, 512, 7])
torch.Size([16, 512])
torch.Size([16, 7, 7])


In [123]:
batch[0].data.shape

KeyError: 0

In [None]:
from torch import nn

In [None]:
lstm = nn.LSTM(input_size=9, hidden_size=32, batch_first=True)

In [None]:
outputs = lstm(batch[0])

In [None]:
output, (hidden, cell) = outputs
output.data.shape

torch.Size([8192, 32])