In [None]:
import numpy as np
import torch
from EnhancerDataset import EnhancerDataset

def enhancer_collate_fn(batch):
    """Pad a batch of variable-length embedding sequences to a common length.

    Padding only to the longest sequence in *this* batch (rather than a global
    maximum) reduces memory usage during training.

    Args:
        batch: iterable of ``(sequence, label)`` pairs as yielded by the dataset.
            Each ``sequence`` is expected to be array-like of shape
            ``(seq_len, embed_dim)`` with the same ``embed_dim`` across the batch
            (TODO confirm against EnhancerDataset). ``seq_len`` may differ per
            example and may be 0.

    Returns:
        padded: float32 tensor of shape ``(B, max_len, embed_dim)``; rows past each
            sequence's own length are zero.
        masks: bool tensor of shape ``(B, max_len)``; ``True`` marks a real token,
            ``False`` marks padding. NOTE: ``torch.nn.MultiheadAttention``'s
            ``key_padding_mask`` uses the opposite convention (``True`` = ignore),
            so pass ``~masks`` there.
        labels: float32 tensor built from the per-example labels (shape ``(B,)``
            when labels are scalars).

    Raises:
        ValueError: if the batch is empty, or every sequence in it is empty (the
            embedding width cannot be determined).
    """
    batch = list(batch)
    if not batch:
        raise ValueError("enhancer_collate_fn: received an empty batch")

    # Tuple of B sequences, each (seq_len_i, embed_dim) -- lengths vary.
    sequences, labels = zip(*batch)

    arrays = [np.asarray(seq, dtype=np.float32) for seq in sequences]
    lengths = [len(arr) for arr in arrays]
    max_len = max(lengths)

    # Read embed_dim from a non-empty sequence: indexing sequences[0][0] would
    # raise IndexError whenever the first sequence happens to be empty.
    embed_dim = next((arr.shape[-1] for arr in arrays if len(arr) > 0), None)
    if embed_dim is None:
        raise ValueError("enhancer_collate_fn: every sequence in the batch is empty")

    # Preallocate zero-filled outputs and copy each sequence in; the untouched
    # tail of each row is the padding.
    padded = np.zeros((len(arrays), max_len, embed_dim), dtype=np.float32)
    masks = np.zeros((len(arrays), max_len), dtype=bool)
    for i, (arr, length) in enumerate(zip(arrays, lengths)):
        if length:  # an empty (0,) array cannot be broadcast into a (0, E) slice
            padded[i, :length] = arr
        masks[i, :length] = True  # 1 for real token, 0 for padded

    labels = torch.tensor(labels, dtype=torch.float32)

    return torch.from_numpy(padded), torch.from_numpy(masks), labels

In [None]:
from torch.utils.data import DataLoader
# Import the class, not the module. A bare `import EnhancerDataset` rebinds the
# name `EnhancerDataset` to the module object, so calling it below would raise
# "TypeError: 'module' object is not callable".
from EnhancerDataset import EnhancerDataset

BATCH_SIZE = 16

enhancer_data = EnhancerDataset(data_type='train')

dataloader = DataLoader(enhancer_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=enhancer_collate_fn)

# Smoke test: one full pass so every batch goes through the collate function,
# checking that the returned tensors agree on batch size and padded length.
n_batches = 0
for batch_padded, batch_mask, batch_labels in dataloader:
    assert batch_padded.shape[:2] == batch_mask.shape
    assert len(batch_labels) == batch_padded.shape[0]
    n_batches += 1

n_batches

TypeError: tensor() got an unexpected keyword argument 'dype'