In [1]:
import torch
from torch.fft import fft, ifft, fftshift  # use torch instead of scipy for speed
from scipy.signal import butter, filtfilt
import numpy as np

In [2]:
# show the visible GPUs, their memory usage and the driver / CUDA versions
# before allocating the large tensors below
!nvidia-smi

Fri Jul  5 17:05:06 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:1A:00.0 Off |                  Off |
| N/A   33C    P0    41W / 300W |      3MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:1C:00.0 Off |                  Off |
| N/A   32C    P0    40W / 300W |      3MiB / 16384MiB |      0%      Default |
|       

In [3]:
def get_highpass_filter(fs=30000, cutoff=300, device=torch.device("cuda"), NT=30122, order=3):
    """Compute the time-domain impulse response of a zero-phase Butterworth high-pass filter.

    A unit impulse placed at index ``NT // 2`` of an ``NT``-sample buffer is
    passed forwards and backwards through the filter (``scipy.signal.filtfilt``),
    so the returned response is centred at ``NT // 2``. ``fft_highpass`` can
    then bring it into the Fourier domain for a given signal length.

    Args:
        fs: sampling rate in Hz.
        cutoff: high-pass cutoff frequency in Hz.
        device: torch device the returned tensor is moved to.
        NT: length of the impulse-response buffer in samples (previously
            hard-coded to 30122, which remains the default).
        order: Butterworth filter order (previously hard-coded to 3).

    Returns:
        1-D float32 tensor of length ``NT`` on ``device``.
    """
    # a butterworth filter is specified in scipy
    b, a = butter(order, cutoff, fs=fs, btype="high")

    # a signal with a single entry is used to compute the impulse response
    impulse = np.zeros(NT)
    impulse[NT // 2] = 1

    # forward-backward filtering -> symmetric (zero-phase) response;
    # .copy() guarantees a plain contiguous array for torch.from_numpy
    response = filtfilt(b, a, impulse).copy()

    return torch.from_numpy(response).to(device).float()


def get_fwav(NT=30122, fs=30000, device=torch.device("cuda"), cutoff=300, order=3):
    """Precompute a zero-phase Butterworth high-pass filter in the Fourier domain.

    Intended for FFT-based filtering in PyTorch. The impulse response is
    computed directly at length ``NT`` (it could instead be padded from a
    shorter response for larger ``NT``), centred at index ``NT // 2``, and
    returned as its FFT.

    Args:
        NT: number of samples of the signal the filter will be applied to.
        fs: sampling rate in Hz.
        device: torch device the returned tensor is placed on.
        cutoff: high-pass cutoff frequency in Hz (previously hard-coded to
            300, which remains the default).
        order: Butterworth filter order (previously hard-coded to 3).

    Returns:
        1-D complex tensor of length ``NT`` on ``device`` (complex64, because
        the impulse response is cast to float32 before the FFT).
    """
    # a butterworth filter is specified in scipy
    b, a = butter(order, cutoff, fs=fs, btype="high")

    # a signal with a single entry is used to compute the impulse response
    impulse = np.zeros(NT)
    impulse[NT // 2] = 1

    # symmetric (zero-phase) filter from scipy's forward-backward filtering
    wav = filtfilt(b, a, impulse).copy()
    wav = torch.from_numpy(wav).to(device).float()

    # the filter will be used directly in the Fourier domain
    return fft(wav)


def fft_highpass(hp_filter, NT=30122):
    """Return the FFT of ``hp_filter`` after fitting it to exactly ``NT`` samples.

    Adapted from Pachitariu's Kilosort4. A filter longer than ``NT`` is
    centre-cropped; one of exactly ``NT`` samples is transformed as-is; a
    shorter one is surrounded by zeros (any odd leftover sample goes on the
    right) so that it stays centred.
    """
    n_taps = hp_filter.shape[0]

    if n_taps > NT:
        # keep the central NT samples
        start = (n_taps - NT) // 2
        return fft(hp_filter[start : start + NT])

    if n_taps == NT:
        return fft(hp_filter)

    # shorter than NT: zero-pad on both sides
    dev = hp_filter.device
    left = (NT - n_taps) // 2
    right = NT - n_taps - left
    padded = torch.cat(
        [torch.zeros(left, device=dev), hp_filter, torch.zeros(right, device=dev)]
    )
    return fft(padded)

In [23]:
# release cached-but-unused GPU memory held by PyTorch's caching allocator
# (useful before re-running the large allocations in the cells below)
torch.cuda.empty_cache()

In [26]:
# (7s)
SEED = 0  # fixed seed so Restart & Run All reproduces the same noise
n_sites = 384
sfreq = 40000  # sampling rate (Hz)
t_secs = 10
n_samples = t_secs * sfreq
gain = 1000

# simulate toy voltage signal (sites x samples)
# kept in torch's default floating dtype (float32): the channel mean and the
# FFT filtering below operate on floating-point data
device = torch.device("cuda:0")
torch.manual_seed(SEED)
# scale in place to avoid allocating a second full-size tensor on the GPU
signal = torch.normal(0, 3, size=(n_sites, n_samples), device=device).mul_(gain)

In [19]:
# centre every channel (row) on zero, then subtract, at each sample, the
# median taken across channels.
# NOTE: with an even number of channels torch.median returns the lower of the
# two middle values (np.median would average them).
# NOTE: modifies `signal` in place - run once per fresh simulation.
signal -= signal.mean(dim=1, keepdim=True)
signal -= signal.median(dim=0).values

In [20]:
# NOTE: this cell overwrites `signal` with its filtered version; re-running it
# without first re-running the simulation/centering cells filters twice.

# time-domain impulse response (30122-sample buffer); not used in this cell
hp_filter = get_highpass_filter(fs=sfreq, cutoff=300, device=signal.device)

# Fourier-domain filter sized to the full recording, built on the same device
# as the data so the product below never mixes devices
fwav = get_fwav(NT=signal.shape[-1], fs=sfreq, device=signal.device)

# multiplying by conj(fwav) is a circular cross-correlation with the impulse
# response centred at NT // 2; fftshift rolls the result back by that offset
signal = torch.real(ifft(fft(signal, dim=-1) * torch.conj(fwav), dim=-1))
signal = fftshift(signal, dim=-1)

In [21]:
# quick look at the filtered traces (torch abbreviates large tensors with "...")
signal

tensor([[-1136.5408,  2438.9475, -1039.4617,  ..., -1049.9680, -1091.5852,
           551.5223],
        [ 3074.2878, -2892.0371,   147.3754,  ...,  2721.6379,  2855.7241,
          5758.3999],
        [ 5091.6714,  4180.6211,   533.6821,  ...,  1287.2253,  2321.1328,
         -4523.5352],
        ...,
        [-1775.0090,   954.2173, -1634.2181,  ...,  3620.5918,  -777.2347,
          4115.3276],
        [ 2659.4995,   525.9730,  1112.8335,  ...,  -509.4645,   114.5115,
          1000.7498],
        [ -259.1798, -2103.2576, -5756.4375,  ..., -6795.0771,  3380.0742,
         -2601.6870]], device='cuda:0')

In [8]:
# sanity check against the config cell: expect (n_sites, n_samples).
# The saved output below (2400000 samples) does not match
# n_samples = t_secs * sfreq = 400000 and presumably comes from an earlier run
# with different settings - re-run top to bottom to refresh it.
assert signal.shape == (n_sites, n_samples), f"unexpected shape {tuple(signal.shape)}"
signal.shape

torch.Size([384, 2400000])