# In [1]:
import torch

def apply_mask_to_tensor(X, mask_length, set_nth_to_zero=True):
    """Zero out columns of ``X`` following a periodic "every n-th" pattern.

    Columns are counted along dimension 1. With ``n = mask_length``, the
    "n-th" columns are those at 0-based indices ``n-1, 2n-1, 3n-1, ...``.

    If ``set_nth_to_zero`` is True, every n-th column is set to zero and the
    other columns are left untouched. If False, the pattern is inverted:
    every n-th column is kept and all other columns are set to zero.

    Examples for a row ``1, 2, 3, 4, ...``::

        params      result
        ------------------------------------------------------------
        4, True:    1, 2, 3, 0, ...   every 4th element set to zero
        3, True:    1, 2, 0, 4, ...   every 3rd element set to zero
        2, True:    1, 0, 3, 0, ...   every 2nd element set to zero
        2, False:   0, 2, 0, 4, ...   inverse of '2, True'
        3, False:   0, 0, 3, 0, ...   2 of every 3 elements set to zero
        4, False:   0, 0, 0, 4, ...   3 of every 4 elements set to zero

    Args:
        X: Tensor with at least two dimensions; the mask is applied along
            dimension 1. ``X`` is modified **in place**.
        mask_length: Period ``n`` of the pattern; an integer >= 1. Any number
            of columns is supported, including fewer than ``mask_length``.
        set_nth_to_zero: Zero the n-th columns (True), or zero every column
            except the n-th ones (False).

    Returns:
        The same tensor object ``X``, after modification.

    Raises:
        TypeError: If ``mask_length`` is not an integer.
        ValueError: If ``mask_length`` is less than 1.
    """
    import operator

    # operator.index accepts int-like values (int, numpy/torch integer
    # scalars) and rejects floats, just like list repetition would.
    mask_length = operator.index(mask_length)
    if mask_length < 1:
        raise ValueError(f"mask_length must be >= 1, got {mask_length}")

    # Derive the mask arithmetically from column indices so that it is always
    # a bool tensor (for every column count) and lives on X's device.
    # Concatenating torch.tensor(list) pieces instead would yield an empty
    # float32 piece whenever the column count is a multiple of (or smaller
    # than) mask_length, promoting the whole mask to float and breaking the
    # indexing below.
    cols = torch.arange(X.shape[1], device=X.device)
    is_nth = (cols % mask_length) == (mask_length - 1)
    mask = is_nth if set_nth_to_zero else ~is_nth

    # Zero the selected columns for all leading indices at once (in place).
    X[:, mask] = 0
    return X


# Example usage: mask a 3x7 integer tensor with a period of 4.
X = torch.tensor([
    [1, 2, 3, 8, 7, 6, 5],
    [6, 1, 2, 3, 8, 9, 10],
    [11, 16, 17, 18, 19, 13, 14],
])
print(X)

# Period of the mask, i.e. the "n" in "every n-th column".
mask_length = 4

# set_nth_to_zero=False inverts the pattern: only every 4th column survives,
# every other column is overwritten with 0.
X = apply_mask_to_tensor(X, mask_length, set_nth_to_zero=False)
print(X)

# Output:
# tensor([[ 1,  2,  3,  8,  7,  6,  5],
#         [ 6,  1,  2,  3,  8,  9, 10],
#         [11, 16, 17, 18, 19, 13, 14]])
# tensor([[ 0,  0,  0,  8,  0,  0,  0],
#         [ 0,  0,  0,  3,  0,  0,  0],
#         [ 0,  0,  0, 18,  0,  0,  0]])
