In [1]:
import torch.multiprocessing as mp
import torch


def mask_batch_subset(x_subset, mask_mtd, mask_ratio):
    """Overwrite a random set of sequence positions in every sample of a batch.

    For each sample, ``num_mask = int(1 + mask_ratio * (seq_len - 2))``
    distinct positions along dim 1 are drawn uniformly without replacement,
    and the slice at each drawn position is replaced by one shared fill frame.

    Args:
        x_subset: Tensor indexed as ``[batch, sequence, ...]`` (the notebook's
            example is ``[batch, seq, channels, height, width]``). It is
            modified IN PLACE.
        mask_mtd: ``"zeros"`` to fill with zeros, or ``"random"`` to fill with
            a single uniform ``[0, 1)`` frame shared by all masked positions.
        mask_ratio: Number used in the ``num_mask`` formula above.

    Returns:
        ``(x_subset, idx)`` where ``x_subset`` is the same tensor object that
        was passed in and ``idx`` has shape ``(batch, num_mask)`` and holds the
        masked sequence positions of each sample.

    Raises:
        ValueError: If ``mask_mtd`` is neither ``"zeros"`` nor ``"random"``.
    """
    # Fail fast: previously an unknown method crashed later with an
    # UnboundLocalError, after random numbers had already been consumed.
    if mask_mtd not in ("zeros", "random"):
        raise ValueError(
            f"mask_mtd must be 'zeros' or 'random', got {mask_mtd!r}")

    batch_size, seq_len = x_subset.shape[0], x_subset.shape[1]
    # Everything after (batch, sequence) is one "frame"; taking the full tail
    # keeps the 5-D case unchanged and also supports other ranks.
    frame_shape = tuple(x_subset.shape[2:])
    num_mask = int(1 + mask_ratio * (seq_len - 2))

    # Equal weights -> every position is equally likely; one row per sample
    # so each sample gets its own independent draw.
    weights = torch.ones(seq_len).expand(batch_size, -1)
    idx = torch.multinomial(weights, num_mask, replacement=False)

    if mask_mtd == "zeros":
        # new_zeros matches the dtype and device of the tensor being written.
        masked_tensor = x_subset.new_zeros(frame_shape)
    else:
        # Sampled exactly as before, then cast so the assignment below also
        # works when x_subset is not the default float dtype.
        masked_tensor = torch.rand(frame_shape).to(
            device=x_subset.device, dtype=x_subset.dtype)

    # Row b of batch_indices is [b, b, ...], pairing every drawn position in
    # idx[b] with its own sample for the advanced-indexing assignment.
    batch_indices = torch.arange(
        batch_size, device=x_subset.device).unsqueeze(1).expand(-1, num_mask)
    x_subset[batch_indices, idx] = masked_tensor
    return x_subset, idx


def _mask_subset_with_seed(seed, x_subset, mask_mtd, mask_ratio):
    """Pool worker entry point: reseed torch, then mask one chunk."""
    torch.manual_seed(seed)
    return mask_batch_subset(x_subset, mask_mtd, mask_ratio)


def parallel_mask(x, mask_mtd="zeros", test_flag=False, mask_ratio=None, num_processes=8):
    """Mask ``x`` by splitting the batch across a pool of worker processes.

    The batch dimension is split with ``torch.chunk``, every chunk is masked
    by ``mask_batch_subset`` in a worker, and the masked chunks are
    concatenated back along dim 0.

    Args:
        x: Tensor indexed as ``[batch, sequence, ...]``.
        mask_mtd: ``"zeros"`` or ``"random"``; forwarded to
            ``mask_batch_subset``.
        test_flag: When falsy, ``mask_ratio`` is ignored and replaced by one
            uniform random value shared by all chunks. When truthy, the
            supplied ``mask_ratio`` is used.
        mask_ratio: Masking ratio used when ``test_flag`` is truthy.
        num_processes: Upper bound on the number of chunks / worker processes.

    Returns:
        ``(masked_x, indices)`` where ``masked_x`` is the concatenation of the
        masked chunks and ``indices`` is a tuple with one index tensor per
        chunk (NOT a single concatenated tensor).

    Raises:
        ValueError: If ``test_flag`` is truthy but ``mask_ratio`` is ``None``.

    Notes:
        * In this notebook's own timing the pooled version took about 2.05 s
          versus about 0.019 s for the single-process ``mask`` on the same
          tensor, so process start-up and tensor transfer dominate here.
        * NOTE(review): workers mask their chunk in place and tensors are
          handed over by ``torch.multiprocessing``; whether the caller's ``x``
          is modified as a side effect should be confirmed before relying on
          ``x`` afterwards.
    """
    if not test_flag:
        mask_ratio = torch.rand(1).item()
    elif mask_ratio is None:
        # Previously this surfaced as a TypeError raised inside a worker.
        raise ValueError("mask_ratio must be provided when test_flag is True")

    # Split the batch; torch.chunk may return fewer chunks than requested
    # when the batch is small.
    subsets = torch.chunk(x, num_processes, dim=0)

    # Workers do not reseed torch on their own, so without an explicit seed
    # they start from identical RNG state and draw the same mask positions.
    # Draw one distinct seed per chunk in the parent instead.
    seeds = torch.randint(0, 2**31 - 1, (len(subsets),)).tolist()
    args = [(seed, subset, mask_mtd, mask_ratio)
            for seed, subset in zip(seeds, subsets)]

    # No point in starting more workers than there are chunks.
    pool_size = max(1, min(num_processes, len(subsets)))
    with mp.Pool(processes=pool_size) as pool:
        results = pool.starmap(_mask_subset_with_seed, args)

    # Reassemble the batch in the original chunk order.
    masked_subsets, indices = zip(*results)
    masked_x = torch.cat(masked_subsets, dim=0)

    return masked_x, indices

In [2]:
import time
# Synthetic benchmark input, uniform in [0, 1), laid out as
# [batch_size, sequence_length, channels, height, width].
x = torch.rand((200, 10, 3, 224, 224))

In [3]:
# Time the multi-process masking path on the example tensor.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed intervals; time.time() is wall-clock time and can jump if the system
# clock is adjusted.
start_time = time.perf_counter()

masked_x, indices = parallel_mask(x)  # call being measured

end_time = time.perf_counter()

time_consumed = end_time - start_time  # elapsed seconds

print(f"Time consumed: {time_consumed} seconds")

Time consumed: 2.049285888671875 seconds


In [4]:
def mask(x, mask_mtd="zeros", test_flag=False, mask_ratio=None):
    """Overwrite a random set of sequence positions in every sample of ``x``.

    For each sample, ``num_mask = int(1 + mask_ratio * (seq_len - 2))``
    distinct positions along dim 1 are drawn uniformly without replacement,
    and the slice at each drawn position is replaced by one shared fill frame.

    Args:
        x: Tensor indexed as ``[batch, sequence, ...]`` (the notebook's example
            is ``[batch, seq, channels, height, width]``). It is modified
            IN PLACE.
        mask_mtd: ``"zeros"`` to fill with zeros, or ``"random"`` to fill with
            a single uniform ``[0, 1)`` frame shared by all masked positions.
        test_flag: When falsy, ``mask_ratio`` is ignored and replaced by one
            uniform random value. When truthy, the supplied ``mask_ratio`` is
            used.
        mask_ratio: Masking ratio used when ``test_flag`` is truthy.

    Returns:
        ``(x, idx)`` where ``x`` is the same tensor object that was passed in
        and ``idx`` has shape ``(batch, num_mask)`` and holds the masked
        sequence positions of each sample.

    Raises:
        ValueError: If ``mask_mtd`` is neither ``"zeros"`` nor ``"random"``,
            or if ``test_flag`` is truthy but ``mask_ratio`` is ``None``.
    """
    # Validate up front: an unknown method used to crash late with an
    # UnboundLocalError, and a missing ratio with an opaque TypeError.
    if mask_mtd not in ("zeros", "random"):
        raise ValueError(
            f"mask_mtd must be 'zeros' or 'random', got {mask_mtd!r}")
    if not test_flag:
        mask_ratio = torch.rand(1).item()
    elif mask_ratio is None:
        raise ValueError("mask_ratio must be provided when test_flag is True")

    batch_size, seq_len = x.shape[0], x.shape[1]
    # Everything after (batch, sequence) is one "frame"; taking the full tail
    # keeps the 5-D case unchanged and also supports other ranks.
    frame_shape = tuple(x.shape[2:])
    num_mask = int(1 + mask_ratio * (seq_len - 2))

    # Equal weights -> every position is equally likely; one row per sample
    # so each sample gets its own independent draw.
    weights = torch.ones(seq_len).expand(batch_size, -1)
    idx = torch.multinomial(weights, num_mask, replacement=False)

    if mask_mtd == "zeros":
        # new_zeros matches the dtype and device of the tensor being written.
        masked_tensor = x.new_zeros(frame_shape)
    else:
        # Sampled exactly as before, then cast so the assignment below also
        # works when x is not the default float dtype.
        masked_tensor = torch.rand(frame_shape).to(
            device=x.device, dtype=x.dtype)

    # Row b of batch_indices is [b, b, ...], pairing every drawn position in
    # idx[b] with its own sample for the advanced-indexing assignment.
    batch_indices = torch.arange(
        batch_size, device=x.device).unsqueeze(1).expand(-1, num_mask)
    x[batch_indices, idx] = masked_tensor
    return x, idx

In [5]:
# Time the single-process masking path on the example tensor.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed intervals; time.time() is wall-clock time and can jump if the system
# clock is adjusted.
# Note: mask() writes into x in place, so after this cell x no longer holds
# the unmasked data.
start_time = time.perf_counter()

masked_x, indices = mask(x)  # call being measured

end_time = time.perf_counter()

time_consumed = end_time - start_time  # elapsed seconds

print(f"Time consumed: {time_consumed} seconds")

Time consumed: 0.01887226104736328 seconds
