In [None]:
# Source recording (EDF) to segment; the annotated copy is written next to it.
path: str = "5_marked.edf"

In [None]:
import locale, pprint, time, calendar
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pyedflib
import torch
import torch.nn as nn
import torch.nn.functional as F

Читаем EDF-файл

In [None]:
# Load every channel of the EDF recording into memory. try/finally guarantees the
# reader is closed even if reading one of the channels fails part-way.
edf_file_path = path
edf_file = pyedflib.EdfReader(edf_file_path)
try:
    num_channels = edf_file.signals_in_file
    channel_labels = edf_file.getSignalLabels()
    # Rate of channel 0 only. NOTE(review): a single rate is used for all channels
    # downstream — confirm every channel shares it.
    sample_rate = edf_file.getSampleFrequency(0)
    n_annotations = edf_file.annotations_in_file
    # channel label -> samples of that channel (a repeated label keeps the last channel)
    epi = {channel_labels[i]: edf_file.readSignal(i) for i in range(num_channels)}
    # annotations stored in the source file itself
    annotation = edf_file.read_annotation()
finally:
    edf_file.close()

In [None]:
# Use the sampling rate of channel 19. The loading cell has already closed its
# reader, so open the file again briefly instead of querying a closed handle.
# NOTE(review): this overrides the channel-0 rate and is applied to every channel
# downstream — confirm channel 19 is the intended reference.
_rate_reader = pyedflib.EdfReader(edf_file_path)
try:
    sample_rate = _rate_reader.getSampleFrequency(19)
finally:
    _rate_reader.close()
sample_rate

In [None]:
def get_sec(time):
    """Convert an 'H:MM:SS' string into a whole number of seconds.

    Raises ValueError if the string does not have exactly three integer fields.
    """
    hours, minutes, seconds = (int(part) for part in time.split(':'))
    return (hours * 60 + minutes) * 60 + seconds

In [None]:
# Plot the 0:10:00-0:20:20 window of every channel and keep each channel's raw
# window in `example` (same order as `keys_epi`).
keys_epi = list(epi.keys())
start = int(get_sec('0:10:00') * sample_rate)
finish = int(get_sec('0:20:20') * sample_rate)
# One subplot row per loaded channel. The row count used to be hard-coded to 23,
# which raised IndexError with more channels and left empty rows with fewer;
# squeeze=False keeps `axs` 2-D even for a single channel.
fig, axs = plt.subplots(len(keys_epi), 1, figsize=(30, 40), dpi=400, squeeze=False)
example = []
for ax, key in zip(axs[:, 0], keys_epi):
    window = epi[key][start:finish]
    ax.plot(window)
    ax.set_title(f'{key}')
    example.append(window)

Определяем структуру нейросети

In [None]:
class UNetConv(nn.Module):
    """Two (Conv1d -> BatchNorm1d -> ReLU) stages; kernel 9 with padding 4 keeps the length."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in_channels -> out_channels, second out_channels -> out_channels.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv1d(stage_in, out_channels, kernel_size=9, padding=4),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
            ]
        # The attribute name and layer order define the state_dict keys
        # ("_model.0.weight", "_model.1.running_mean", ...) that checkpoints use.
        self._model = nn.Sequential(*layers)

    def forward(self, X):
        return self._model(X)
    

class UNetDown(nn.Module):
    """Encoder stage: MaxPool1d(2) (floor-halves the length), then a UNetConv block."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool1d(2)
        conv = UNetConv(in_channels, out_channels)
        # Stored as a Sequential named `_model` so the state_dict keys
        # ("_model.1....") match checkpoints saved from this architecture.
        self._model = nn.Sequential(pool, conv)

    def forward(self, X):
        return self._model(X)
    

class UNetUp(nn.Module):
    """Decoder stage: upsample x2, align length to the skip tensor, concatenate, convolve."""

    def __init__(self, in_channels, in_channels_skip, out_channels):
        super().__init__()
        # Output length of this transposed conv: (L - 1) * 2 - 2 * 3 + 8 = 2 * L.
        self._up = nn.ConvTranspose1d(in_channels, in_channels, kernel_size=8, stride=2, padding=3)
        self._model = UNetConv(in_channels + in_channels_skip, out_channels)

    def forward(self, X_skip, X):
        upsampled = self._up(X)
        # Pad (or crop, when the gap is negative) so the length equals the skip's.
        gap = X_skip.size(2) - upsampled.size(2)
        left = gap // 2
        upsampled = F.pad(upsampled, (left, gap - left))
        merged = torch.cat([X_skip, upsampled], dim=1)
        return self._model(merged)

In [None]:
class UNet(nn.Module):
    """1-D U-Net producing one score per class for every input sample.

    Four UNetDown stages (each floor-halving the length) are mirrored by four
    UNetUp stages that restore the length and concatenate the matching encoder
    output (skip connection); a kernel-1 convolution maps the final features to
    `num_classes` channels.

    Args:
        in_channels: channels of the input tensor (batch, in_channels, length).
        num_classes: channels of the output tensor (batch, num_classes, length).
        base_channels: width of the first stage, doubled at every down-sampling
            level. Defaults to 2, the value that used to be hard-coded; a saved
            state_dict only loads into a model built with the same width.
    """

    def __init__(self, in_channels, num_classes, base_channels=2):
        super(UNet, self).__init__()
        n = base_channels
        self._input = UNetConv(in_channels, n)
        self._down1 = UNetDown(n, 2*n)
        self._down2 = UNetDown(2*n, 4*n)
        self._down3 = UNetDown(4*n, 8*n)
        self._down4 = UNetDown(8*n, 16*n)
        self._up1 = UNetUp(16*n, 8*n, 8*n)
        self._up2 = UNetUp(8*n, 4*n, 4*n)
        self._up3 = UNetUp(4*n, 2*n, 2*n)
        self._up4 = UNetUp(2*n, n, n)
        self._output = nn.Conv1d(n, num_classes, kernel_size=1)

    def forward(self, X):
        """Encode, decode with skip connections, and return per-sample class scores."""
        x1 = self._input(X)
        x2 = self._down1(x1)
        x3 = self._down2(x2)
        x4 = self._down3(x3)
        x = self._down4(x4)
        x = self._up1(x4, x)
        x = self._up2(x3, x)
        x = self._up3(x2, x)
        x = self._up4(x1, x)
        return self._output(x)

In [None]:
# Single input signal, two output classes (0 and 1; class 1 is mapped to 'qrs' below).
model = UNet(in_channels=1, num_classes=2)

Загружаем веса обученной модели

In [None]:
# Inference only: eval() makes BatchNorm use the checkpoint's running statistics
# instead of per-call batch statistics (and stops it overwriting them on each call).
model.eval()
# map_location="cpu" lets a checkpoint saved with CUDA tensors load on a CPU-only machine.
model.load_state_dict(torch.load("good_model_020824.pt", map_location="cpu"))

Вспомогательные функции для маски

In [None]:
# Mask value -> key of the delineation dict built by mask_to_delineation.
# Values <= 0 are never looked up.
v_to_del = {
    1: 'qrs',
}

def remove_small(signal, max_dist=50):
    """Zero out short non-zero runs of a label mask, in place.

    Walks the mask remembering the index of the last zero. When the next zero at
    index ``i`` is fewer than ``max_dist`` samples after it, everything between
    them is cleared. The array start acts as a zero on the left and the array
    end acts as one on the right, so short runs touching either edge are
    treated the same way.

    Args:
        signal: 1-D mutable array (e.g. numpy) of class labels; label 0 means
            "no wave".
        max_dist: bounding zeros closer than this clear the run between them
            (default 50, the previously hard-coded value).

    Returns:
        None; ``signal`` is modified in place.
    """
    n = len(signal)
    last_zero = 0
    for i in range(n):
        if signal[i] == 0:
            if i - last_zero < max_dist:
                signal[last_zero:i] = 0
            last_zero = i
    # A trailing run has no zero after it, so the loop never checks it;
    # treat the end of the array as the closing zero.
    if n > 0 and n - last_zero < max_dist:
        signal[last_zero:] = 0

def merge_small(signal, max_dist=50):
    """Fill short gaps between samples of the same positive label, in place.

    For every label, remembers the index where it was last seen. When a sample
    with positive label ``m`` appears fewer than ``max_dist`` samples after the
    previous ``m``, the whole span in between is overwritten with ``m``.

    Args:
        signal: 1-D numpy array of non-negative integer labels.
        max_dist: largest index distance (exclusive) that is bridged
            (default 50, the previously hard-coded value).

    Returns:
        None; ``signal`` is modified in place.
    """
    if len(signal) == 0:
        return  # nothing to merge, and signal.max() raises on an empty array
    # Start far enough in the past that no label can merge before it is seen once.
    lasts = np.full(signal.max() + 1, -(max_dist + 1))
    for i in range(len(signal)):
        m = signal[i]
        if m > 0 and i - lasts[m] < max_dist:
            signal[lasts[m]:i] = m
        lasts[m] = i

def mask_to_delineation(mask):
    """Convert a per-sample class mask into [start, end] intervals per wave type.

    The mask is first cleaned in place with merge_small and then remove_small.
    Each maximal run of a positive value ``v`` becomes one ``[start, end]`` pair
    (``end`` is the last index of the run, inclusive) under the key
    ``v_to_del[v]``.

    Args:
        mask: 1-D numpy array of integer labels (modified in place).

    Returns:
        dict mapping every name in ``v_to_del`` to a list of [start, end] pairs.
    """
    merge_small(mask)
    remove_small(mask)
    delineation = {name: [] for name in v_to_del.values()}
    i = 0
    mask_length = len(mask)
    while i < mask_length:
        v = mask[i]
        if v > 0:
            run_start = i
            while i < mask_length and mask[i] == v:
                i += 1
            delineation[v_to_del[v]].append([run_start, i - 1])
            # `i` now points at the first sample after the run. Do not skip it:
            # it may start a run of a different class.
        else:
            i += 1
    return delineation

In [None]:
# Delineation key -> matplotlib colour used to shade intervals of that wave type.
wave_type_to_color = dict(qrs="red")

def plot_signal_with_mask(signal, mask, i, verbose=False):
    """Plot `signal` and shade every interval delineated from `mask`.

    Args:
        signal: 1-D samples of one channel, sampled at `sample_rate`.
        mask: per-sample class mask; cleaned in place by mask_to_delineation.
        i: index into `keys_epi`, used only for the plot title.
        verbose: if True, print each (begin, end) sample pair of every interval.

    Returns:
        (begs, times): start sample index and length in samples of each
        interval (`end` is inclusive, so length = end - begin + 1).
    """
    begs = []
    times = []
    plt.figure(figsize=(36, 10), dpi=300)
    plt.title(f"{keys_epi[i]}")
    plt.xlabel("Время (сек)")
    plt.ylabel("Амплитуда (мВ)")
    # Sample k lies at k / sample_rate seconds. linspace(0, len/sr, len) stretched
    # the axis by one sample period and misaligned it with the shaded spans.
    x_axis_values = np.arange(len(signal)) / sample_rate
    plt.plot(x_axis_values, signal, linewidth=2, color="black")

    delineation = mask_to_delineation(mask)
    for wave_type, color in wave_type_to_color.items():
        for begin, end in delineation.get(wave_type, []):
            if verbose:
                print(begin, end)
            begs.append(begin)
            # `end` is the last sample of the run, so the run covers end - begin + 1 samples.
            times.append(end - begin + 1)
            plt.axvspan(begin / sample_rate, (end + 1) / sample_rate, facecolor=color, alpha=0.5)
    return begs, times

In [None]:
def get_mask(signal, edge=500, net=None):
    """Segment one 1-D signal with the network and return a per-sample class mask.

    Args:
        signal: 1-D array-like of samples; fed to the network as a float32
            tensor of shape (1, 1, len(signal)).
        edge: number of samples forced to class 0 at each end of the mask
            (default 500, the previously hard-coded value); 0 disables this.
        net: module to run; defaults to the notebook-level ``model``.

    Returns:
        numpy int64 array holding the arg-max class index of each output position.
    """
    if net is None:
        net = model
    # Inference: BatchNorm must use its stored running statistics. In train mode it
    # would normalise with this signal's own statistics and overwrite the running
    # buffers on every call.
    net.eval()
    batch = torch.as_tensor(np.expand_dims(np.asarray(signal, dtype=np.float32), axis=(0, 1)))
    # No autograd graph is needed; skipping it saves memory on long signals.
    with torch.no_grad():
        scores = net(batch)[0]
    mask = scores.max(dim=0).indices
    if edge > 0:
        # Guarded because mask[-0:] would zero the whole mask.
        mask[:edge] = 0
        mask[-edge:] = 0
    return mask.numpy()

In [None]:
def plot_test_sample(signal, i):
    """Predict a mask for `signal` with get_mask, plot it, and return the intervals.

    Args:
        signal: 1-D samples of one channel.
        i: index into `keys_epi`, used for the plot title.

    Returns:
        (begins, durations) exactly as returned by plot_signal_with_mask.
    """
    mask = get_mask(signal)
    begins, durations = plot_signal_with_mask(signal, mask, i)
    return begins, durations

In [None]:
# Number of per-channel windows collected in `example`.
len(example)

In [None]:
# Label of channel 10, the channel segmented in the cells below.
channel_labels[10]

In [None]:
# Quick visual check of the model on samples 5000-9999 of channel 10's window.
plot_test_sample(example[10][5000:10000], 10)

Размечаем нужный участок

In [None]:
# channel index -> interval onsets / interval durations
annots_begs, annots_times = {}, {}

In [None]:
# Segment the whole analysis window of channel 10 and store the raw results
# (start offsets within the window and durations, both in samples).
# The durations used to be unpacked into `time`, which shadowed the `time`
# module imported at the top of the notebook.
beg, durs = plot_test_sample(example[10], 10)
annots_begs[10] = beg
annots_times[10] = durs

In [None]:
# Convert window-relative sample offsets into absolute seconds from the start of
# the recording (`start` is the window's first sample), and durations into seconds.
# Onsets keep sub-second precision. They used to be rounded to whole seconds, so
# nearby intervals could share one onset. The sampling rate is no longer
# truncated with int().
# NOTE: rebinds the lists from their own values — run this cell once per segmentation.
annots_begs[10] = [(b + start) / sample_rate for b in annots_begs[10]]
annots_times[10] = [d / sample_rate for d in annots_times[10]]

In [None]:
# Onsets of the detected intervals, as converted by the previous cell.
annots_begs[10]

In [None]:
# Number of detected intervals.
len(annots_begs[10])

In [None]:
# Durations of the detected intervals, as converted by the conversion cell.
annots_times[10]

In [None]:
# One description per detected interval: the name of the channel it was found on.
labels = {10: [f'{keys_epi[10]}' for _ in annots_begs[10]]}

In [None]:
# Pick out channel 10's onsets, durations and descriptions for the export.
annot_begs, annot_durs, annot_labels = annots_begs[10], annots_times[10], labels[10]

Формируем аннотацию: начала интервалов, их длительности и подписи

In [None]:
# Positional order matches mne.Annotations(onset, duration, description).
annotation = [
    annot_begs,    # onsets
    annot_durs,    # durations
    annot_labels,  # descriptions
]

In [None]:
# [onsets, durations, descriptions] that will be attached to the recording.
annotation

Записываем новую аннотацию в копию файла с префиксом `nn_`

In [None]:
import mne

In [None]:
# Re-open the same recording with MNE, which is used to write the annotated copy.
# encoding='latin1' is passed to MNE for decoding the file's annotation text
# (presumably to tolerate non-UTF-8 bytes — confirm).
data = mne.io.read_raw_edf(path, encoding='latin1')

In [None]:
# Same argument order as the `annotation` list: onsets, durations, descriptions.
# NOTE(review): orig_time is left at its default — confirm the onsets (seconds from
# the recording start) are interpreted relative to the first sample.
annotations = mne.Annotations(
    onset=np.array(annotation[0]),
    duration=np.array(annotation[1]),
    description=np.array(annotation[2]),
)

In [None]:
# Attach the network-derived annotations to the recording and display them.
# NOTE(review): per the MNE docs, set_annotations replaces the annotations already
# on `data` (including any read from the source EDF). If the original marks should
# be kept too, combine them first (e.g. data.annotations + annotations) — confirm intent.
data.set_annotations(annotations)
data.annotations

In [None]:
# Write the annotated recording next to the source file, prefixing only the file
# name ("dir/x.edf" -> "dir/nn_x.edf"); "nn_" + path would give "nn_dir/x.edf".
out_dir, out_name = os.path.split(path)
data.export(os.path.join(out_dir, "nn_" + out_name), overwrite=True)