In [38]:
from data.stroke_dataset import Stroke_dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from losses.flex_loss import mdn_loss_function
BATCH_SIZE = 16
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

In [35]:
def loss_fn(params, state, batch):
    inputs, _ = batch
    x = inputs[:, :-1]
    y = inputs[:, 1:]
    logits = state.apply_fn({'params': params}, x)
    loss = mdn_loss_function(logits, y)
    return loss

def collate_fn(batch):
    stroke_sequences, sentences = zip(*batch)
    stroke_sequences_padded = pad_sequence([s.clone().detach() if isinstance(
        s, torch.Tensor) else torch.tensor(s) for s in stroke_sequences], batch_first=True, padding_value=0)

    max_sentence_len = max(sentence.shape[0] for sentence in sentences)
    sentences_padded = torch.stack([F.pad(sentence,
                                          (0, 0, 0, max_sentence_len -
                                           sentence.shape[0]),
                                          "constant", 0) for sentence in sentences])
    return stroke_sequences_padded, sentences_padded


In [28]:
# 2. Load dataset
stroke_data = Stroke_dataset(train=True)
dataset_size = len(stroke_data)
train_size = int(0.8 * dataset_size)
val_size = int(0.1 * dataset_size)
test_size = dataset_size - train_size - val_size

# stroke_data.sort_by_sequence_length()



In [55]:
# Example of checking if the dataset is sorted by sequence length

# 1. Inspect a few samples before sorting
print("Before Sorting:")
print("First item length:", len(stroke_data[0][0]))  # Assuming the sequence is the first element in the tuple
print("Middle item length:", len(stroke_data[dataset_size // 2][0]))
print("Last item length:", len(stroke_data[-1][0]))

# 2. Apply sorting
stroke_data.sort_by_sequence_length()

# 3. Inspect the same samples after sorting
print("\nAfter Sorting:")
print("First item length:", len(stroke_data[0][0]))
print("Middle item length:", len(stroke_data[dataset_size // 2][0]))
print("Last item length:", len(stroke_data[-1][0]))

# 4. Additional check with entire dataset (optional)
prev_length = len(stroke_data[0][0])
sorted_correctly = True
for i in range(1, dataset_size):
    current_length = len(stroke_data[i][0])
    if current_length > prev_length:  # Change this condition based on your sorting order
        sorted_correctly = False
        print("length at index", i, " ", current_length, "is smaller than the previous length" , prev_length)
        # break
    prev_length = current_length

print("Is sorted correctly:", sorted_correctly)


Before Sorting:
First item length: 1191
Middle item length: 627
Last item length: 301

After Sorting:
First item length: 1191
Middle item length: 627
Last item length: 301
Is sorted correctly: True


In [64]:
def interleaved_split(dataset, train_frac, val_frac):
    train_indices, val_indices, test_indices = [], [], []
    total_size = len(dataset)

    # Calculate the number of samples for train and validation sets
    train_size = int(total_size * train_frac)
    val_size = int(total_size * val_frac)

    for idx in range(total_size):
        if len(train_indices) < train_size:
            train_indices.append(idx)
        elif len(val_indices) < val_size:
            val_indices.append(idx)
        else:
            test_indices.append(idx)

    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    test_dataset = torch.utils.data.Subset(dataset, test_indices)

    return train_dataset, val_dataset, test_dataset

In [71]:
train_dataset, val_dataset, test_dataset = interleaved_split(
    stroke_data, 0.8, 0.1)

# train_loader = DataLoader(
#     train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, collate_fn=collate_fn)


# Create DataLoader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=collate_fn)

stroke_loader = DataLoader(stroke_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=collate_fn)

# Function to test padding
def test_padding(data_loader):
    for batch_idx, (stroke_sequences_padded, _) in enumerate(data_loader):
        print(f"Batch {batch_idx + 1}:")

        # Check lengths of stroke sequences
        print("  Stroke Sequence Lengths:")
        for i, seq in enumerate(stroke_sequences_padded):
            print(f"    Sequence {i + 1} length: {seq.shape[0]}")  # Assuming the sequence length is in the second dimension

# Call the test function
test_padding(train_loader)

Batch 1:
  Stroke Sequence Lengths:
    Sequence 1 length: 1191
    Sequence 2 length: 1191
    Sequence 3 length: 1191
    Sequence 4 length: 1191
    Sequence 5 length: 1191
    Sequence 6 length: 1191
    Sequence 7 length: 1191
    Sequence 8 length: 1191
    Sequence 9 length: 1191
    Sequence 10 length: 1191
    Sequence 11 length: 1191
    Sequence 12 length: 1191
    Sequence 13 length: 1191
    Sequence 14 length: 1191
    Sequence 15 length: 1191
    Sequence 16 length: 1191
Batch 2:
  Stroke Sequence Lengths:
    Sequence 1 length: 1160
    Sequence 2 length: 1160
    Sequence 3 length: 1160
    Sequence 4 length: 1160
    Sequence 5 length: 1160
    Sequence 6 length: 1160
    Sequence 7 length: 1160
    Sequence 8 length: 1160
    Sequence 9 length: 1160
    Sequence 10 length: 1160
    Sequence 11 length: 1160
    Sequence 12 length: 1160
    Sequence 13 length: 1160
    Sequence 14 length: 1160
    Sequence 15 length: 1160
    Sequence 16 length: 1160
Batch 3:
  Stroke 

In [72]:
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=collate_fn)
test_padding(val_loader)

Batch 1:
  Stroke Sequence Lengths:
    Sequence 1 length: 501
    Sequence 2 length: 501
    Sequence 3 length: 501
    Sequence 4 length: 501
    Sequence 5 length: 501
    Sequence 6 length: 501
    Sequence 7 length: 501
    Sequence 8 length: 501
    Sequence 9 length: 501
    Sequence 10 length: 501
    Sequence 11 length: 501
    Sequence 12 length: 501
    Sequence 13 length: 501
    Sequence 14 length: 501
    Sequence 15 length: 501
    Sequence 16 length: 501
Batch 2:
  Stroke Sequence Lengths:
    Sequence 1 length: 500
    Sequence 2 length: 500
    Sequence 3 length: 500
    Sequence 4 length: 500
    Sequence 5 length: 500
    Sequence 6 length: 500
    Sequence 7 length: 500
    Sequence 8 length: 500
    Sequence 9 length: 500
    Sequence 10 length: 500
    Sequence 11 length: 500
    Sequence 12 length: 500
    Sequence 13 length: 500
    Sequence 14 length: 500
    Sequence 15 length: 500
    Sequence 16 length: 500
Batch 3:
  Stroke Sequence Lengths:
    Sequence 1

In [73]:
for i, (data, label) in enumerate(val_loader):
    if data[0].shape[0] < 100:
        continue
    print("First Data Element:", data[0].shape)
    print("First Label:", label[0].shape)
    if i > 3 * BATCH_SIZE:
        break

First Data Element: torch.Size([501, 3])
First Label: torch.Size([33, 78])
First Data Element: torch.Size([500, 3])
First Label: torch.Size([35, 78])
First Data Element: torch.Size([498, 3])
First Label: torch.Size([37, 78])
First Data Element: torch.Size([497, 3])
First Label: torch.Size([39, 78])
First Data Element: torch.Size([496, 3])
First Label: torch.Size([33, 78])
First Data Element: torch.Size([494, 3])
First Label: torch.Size([34, 78])
First Data Element: torch.Size([492, 3])
First Label: torch.Size([36, 78])
First Data Element: torch.Size([491, 3])
First Label: torch.Size([35, 78])
First Data Element: torch.Size([490, 3])
First Label: torch.Size([32, 78])
First Data Element: torch.Size([488, 3])
First Label: torch.Size([35, 78])
First Data Element: torch.Size([487, 3])
First Label: torch.Size([32, 78])
First Data Element: torch.Size([486, 3])
First Label: torch.Size([33, 78])
First Data Element: torch.Size([484, 3])
First Label: torch.Size([31, 78])
First Data Element: torch

In [69]:
val_dataset[2][0].shape

torch.Size([501, 3])