In [None]:
import numpy as np
import pandas as pd
#import wfdb
import os
from glob import glob
from matplotlib import pyplot as plt
from sklearn.model_selection import KFold
from sklearn.metrics import classification_report,confusion_matrix
from sklearn.metrics import f1_score,accuracy_score,precision_score,recall_score
import scipy.signal as signal
from scipy.signal import butter, lfilter
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.alexnet import alexnet
# Seed every RNG the notebook uses so that Restart & Run All is reproducible.
seed = 2020
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)   # silently ignored when CUDA is unavailable
np.random.seed(seed)               # drives np.random.shuffle / np.random.normal below
# cuDNN: request deterministic kernels AND disable the autotuner, which may
# otherwise select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Record directories (one per record) and every segment file beneath them.
# Sorting makes the downstream (unshuffled) KFold splits reproducible.
reclist = sorted(glob("../apnea/seg/*"))
dtlist = sorted(glob("../apnea/seg/*/*.npy*"))
if not dtlist:
    raise FileNotFoundError("no segment files found under ../apnea/seg/*/ - check the data path")

# Class label = the character just before the first '_' of the file name
# ('A' / 'N', the keys of ApneaECGDict).  Parsing the basename only keeps
# underscores in parent directory names from corrupting the label.
lab_list = [os.path.basename(seg).split('_')[0][-1] for seg in dtlist]
print(lab_list.count('A') / len(lab_list), lab_list.count('N') / len(lab_list))

reclist[0:5]

In [None]:
def cheb_bandpass_filter(data, lowcut, highcut, signal_freq, filter_order, rs=40):
    """
    Design a Chebyshev type-II band-pass filter and apply it to ``data``.

    The filter is designed in transfer-function (b, a) form with
    ``scipy.signal.cheby2`` and applied once, forward only, with
    ``scipy.signal.lfilter``, so the output carries the filter's phase delay.

    NOTE(review): with very low ``lowcut`` (e.g. 0.01 Hz at 100 Hz, i.e.
    Wn = 2e-4) the (b, a) form can be poorly conditioned; SciPy recommends
    ``output='sos'`` + ``sosfilt`` for such designs.  Kept as (b, a) here so
    features match previously trained models.

    :param array_like data: raw signal samples (filtered along the last axis)
    :param float lowcut: lower pass-band edge in Hz
    :param float highcut: upper pass-band edge in Hz
    :param float signal_freq: sampling rate in samples per second (Hz)
    :param int filter_order: order passed to ``cheby2``; a band-pass design
        yields a filter of order ``2 * filter_order``
    :param float rs: minimum stop-band attenuation in dB (default 40, the
        value previously hard-coded)
    :return ndarray: filtered signal with the same shape as ``data``
    """
    nyquist_freq = 0.5 * signal_freq
    low = lowcut / nyquist_freq
    high = highcut / nyquist_freq
    b, a = signal.cheby2(filter_order, rs, [low, high], btype='band', analog=False)
    return lfilter(b, a, data)

In [None]:
# Preprocessing figure: a raw ECG segment, the band-pass-filtered segment, and its spectrogram
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
import numpy as np
def _plot_trace_panels(trace, row):
    """Draw one figure row: the whole trace (left half) and its first 500 samples (right third)."""
    # Whole segment; y ticks on the right so they sit next to the zoom panel.
    ax_full = plt.subplot(3, 2, 2 * row + 1)
    ax_full.plot(trace)
    ax_full.set_ylim(-1, 4)
    ax_full.set_xticks([])
    ax_full.yaxis.set_ticks_position('right')
    ax_full.set_xlabel('0-60 seconds', rotation=0, ha='center', va='top', font='Times New Roman')

    # First 500 samples (5 s at the 100 Hz rate used below).
    ax_zoom = plt.subplot(3, 3, 3 * row + 3)
    ax_zoom.plot(trace[0:500])
    ax_zoom.set_ylim(-1, 4)
    ax_zoom.set_xticks([])
    ax_zoom.set_ylabel('Voltage ($mv$)', rotation=0, ha='right', va='top', font='Times New Roman')
    ax_zoom.set_xlabel('0-5 seconds', rotation=0, ha='center', va='top', font='Times New Roman')


filetmp = dtlist[1925]  # hand-picked example segment (needs at least 1926 segments)
dt = np.load(filetmp)

fig = plt.figure(dpi=512)
fig.subplots_adjust(hspace=0.2)

# Row 0: raw segment.  Row 1: after the 0.01-38 Hz Chebyshev band-pass.
_plot_trace_panels(dt, row=0)
dt = cheb_bandpass_filter(dt, 0.01, 38, 100, 4)
_plot_trace_panels(dt, row=1)

# Row 2: PSD spectrogram of the filtered segment, with the same parameters
# ApneaECGDataset.getFeature uses.  Names other than `Sxx` are discarded so
# the Figure handle and the built-in `input` are not shadowed.
_, _, Sxx = signal.spectrogram(dt, fs=100.0, window='hamming', nperseg=128, noverlap=64, nfft=128,
                               detrend='constant', return_onesided=True, scaling='density',
                               axis=-1, mode='psd')

ax = plt.subplot(3, 1, 3)
im = ax.imshow(Sxx, aspect='auto', interpolation='nearest')
ax.plot([26] * 92, 'r--', linewidth=3)  # cut-off row: the model keeps rows 0-25
ax.yaxis.set_ticks_position('left')
ax.set_xlim(0, 91)
ax.set_ylim(0, 50)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('0-60 seconds', rotation=0, ha='center', va='top', font='Times New Roman')
cb = plt.colorbar(im, orientation='vertical', fraction=0.2)

plt.tight_layout()
plt.savefig('RawApneaScheme.pdf', dpi=512)
plt.show()
plt.close("all")
print(filetmp)

In [None]:
# Candidate held-out subject groups.
# NOTE(review): getTrainTestList(opt='blind') defines its own, different
# `subjects` list locally and does not read this one.
subjects = [
    ['a11'],
    ['a15'],
    ['a17'],
    ['b01'],
    ['c07'],
    ['a11', 'a15', 'a17', 'b01', 'c07'],
    ['a14', 'a19', 'b05', 'c01', 'c07'],
    ['a04', 'a19', 'b05', 'c01', 'c09'],
    ['b02', 'b03'],
    ['x16', 'x21'],
]
len(subjects)

In [None]:
from sklearn.model_selection import KFold
def _record_name(rec):
    """Return the record id (last path component) of a record directory path."""
    return os.path.basename(os.path.normpath(rec))


def _collect_segments(recs):
    """Return the segment files of every record directory in ``recs``, record order kept."""
    segments = []
    for rec in recs:
        segments += glob(os.path.join(rec, '*.npy*'))
    return segments


def _kfold_split(items, fold, n_splits=10):
    """Return the ``fold``-th (1-based) ``(train, test)`` partition of ``items``.

    Uses an unshuffled ``KFold``, so each test fold is a contiguous slice of ``items``.
    """
    if not 1 <= fold <= n_splits:
        raise ValueError(f"fold must be between 1 and {n_splits}, got {fold}")
    items = np.array(items)
    train_idx, test_idx = list(KFold(n_splits=n_splits).split(items))[fold - 1]
    return list(items[train_idx]), list(items[test_idx])


def getTrainTestList(reclist, opt='rec_cv', fold=1, data_list=None):
    """
    Split the segment files into a training list and a test list.

    :param list reclist: record directory paths (e.g. ``../apnea/seg/a01``)
    :param str opt: split strategy

        - ``'rec_cv'``: 10-fold CV over *records*; all segments of a record
          land on the same side.
        - ``'physionet'``: records whose name starts with ``x`` form the test
          set, all others the training set.
        - ``'physionet_train'``: 10-fold CV over the segments of the non-``x``
          records only.
        - ``'blind'``: hold out subject group ``fold`` (1-9) of the list below.
        - anything else: 10-fold CV over individual segments of ``data_list``.
    :param int fold: 1-based fold index (1-10 for the CV options, 1-9 for ``'blind'``)
    :param list data_list: segment files for the segment-level CV; defaults to
        the notebook-global ``dtlist`` (the previous behaviour)
    :return tuple: ``(train_dtlist, test_dtlist)``
    :raises ValueError: if ``fold`` is out of range for the chosen option
    """
    if opt == 'rec_cv':
        train_rec, test_rec = _kfold_split(reclist, fold)
        return _collect_segments(train_rec), _collect_segments(test_rec)

    if opt == 'physionet':
        train_rec = [rec for rec in reclist if not _record_name(rec).startswith('x')]
        test_rec = [rec for rec in reclist if _record_name(rec).startswith('x')]
        print(len(train_rec), len(test_rec))
        return _collect_segments(train_rec), _collect_segments(test_rec)

    if opt == 'physionet_train':
        # Previously referenced an undefined `kf` and ignored `train_rec`;
        # now cross-validates over the segments of the non-'x' records only.
        train_rec = [rec for rec in reclist if not _record_name(rec).startswith('x')]
        return _kfold_split(_collect_segments(train_rec), fold)

    if opt == 'blind':
        subjects = [['a11'],
                    ['a15', 'x27', 'x28'],
                    ['a17', 'x12'],
                    ['b01', 'x03'],
                    ['c07', 'x34'],
                    ['a11', 'a15', 'x27', 'x28', 'a17', 'x12', 'b01', 'x03', 'c07', 'x34'],
                    ['a14', 'a19', 'x05', 'x08', 'x25', 'b05', 'x11', 'c01', 'x35', 'c07', 'x34'],
                    ['a04', 'a19', 'x05', 'x08', 'x25', 'b05', 'x11', 'c01', 'x35', 'c09'],
                    ['b02', 'b03', 'x16', 'x21']]
        # Guard explicitly: fold=0 would otherwise silently select the last group.
        if not 1 <= fold <= len(subjects):
            raise ValueError(f"fold must be between 1 and {len(subjects)} for opt='blind', got {fold}")
        held_out = subjects[fold - 1]
        train_rec = [rec for rec in reclist if _record_name(rec) not in held_out]
        test_rec = [rec for rec in reclist if _record_name(rec) in held_out]
        print(len(train_rec), len(test_rec), test_rec)
        return _collect_segments(train_rec), _collect_segments(test_rec)

    # Any other `opt`: segment-level CV (`reclist` is not used here).
    segments = data_list if data_list is not None else dtlist
    return _kfold_split(segments, fold)

def dtclean(dt_path):
    """Flag a near-flat segment.

    Loads the array at ``dt_path`` and returns 1 (after printing the path) when
    its standard deviation is below 0.1, otherwise returns 0.
    """
    is_flat = np.load(dt_path).std() < 0.1
    if not is_flat:
        return 0
    print(dt_path)
    return 1
def _keep_clean(paths):
    """Keep only the segment files that dtclean does not flag as near-flat (order preserved)."""
    return [path for path in paths if not dtclean(path)]


# opt='kfold' matches no named strategy, so this is the segment-level 10-fold CV; fold 6 of 1-10.
train_dtlist, test_dtlist = getTrainTestList(reclist, opt='kfold', fold=6)
print('*********')
print(len(train_dtlist), len(test_dtlist))

# Drop near-flat segments; the training order is shuffled with the seeded NumPy RNG.
train_dtlist = _keep_clean(train_dtlist)
np.random.shuffle(train_dtlist)
test_dtlist = _keep_clean(test_dtlist)
print(len(train_dtlist), len(test_dtlist))

In [None]:
import numpy as np
from torch.utils.data import DataLoader,Dataset
# Label character parsed from a segment file name -> class index.
ApneaECGDict = {'N': 0,
                'A': 1}


class ApneaECGDataset(Dataset):
    """Map-style dataset yielding ``(spectrogram_feature, label)`` per segment file.

    :param list filelist: paths to ``.npy`` segment files; the label is the
        character just before the first ``'_'`` in the file name and must be a
        key of ``ApneaECGDict``
    :param bool istrain: if True, add Gaussian noise to the raw signal before
        feature extraction (training-time augmentation)
    :param float noise_std: standard deviation of that noise (default 0.1,
        the value previously hard-coded)
    """

    def __init__(self, filelist, istrain=False, noise_std=0.1):
        self.filelist = filelist
        self.istrain = istrain
        self.noise_std = noise_std

    def getFeature(self, dt):
        """Band-pass filter ``dt`` and return the 26 lowest rows of its PSD spectrogram.

        With ``fs=100`` and ``nfft=128`` the rows are 100/128 Hz apart, so rows
        0-25 span 0 to ~19.5 Hz.  The number of columns (time frames) depends
        on the segment length (128-sample windows, 64-sample hop).
        """
        dt = cheb_bandpass_filter(dt, 0.01, 38, 100, 4)
        _, _, Sxx = signal.spectrogram(dt, fs=100.0, window=('hamming'), nperseg=128, noverlap=64, nfft=128,
                                       detrend='constant', return_onesided=True, scaling='density',
                                       axis=-1, mode='psd')
        return Sxx[0:26]

    def __getitem__(self, index):
        """Load segment ``index`` and return ``(feature, label)``."""
        dt_path = self.filelist[index]
        # Parse the label from the file name only; an underscore anywhere in a
        # parent directory name used to corrupt it (KeyError).
        label = ApneaECGDict[os.path.basename(dt_path).split('_')[0][-1]]
        data = np.load(dt_path)

        if self.istrain:
            data = data + np.random.normal(0, self.noise_std, data.shape)
        return self.getFeature(data), label

    def __len__(self):
        return len(self.filelist)

    
class MyLSTM(nn.Module):
    """CNN front-end + 2-layer bidirectional LSTM with dot-product attention.

    The LayerNorm shapes in ``branch1`` are hard-coded for a ``(batch, 26, 92)``
    input (26 spectrogram rows x 92 frames); other sizes fail there.
    Shape trace for that input:
      (B,26,92) -> unsqueeze -> (B,1,26,92) -> branch1 -> (B,64,6,23)
      -> permute -> (B,6,64,23) -> avg -> (B,6,8,8) -> view -> (B,6,64)
      -> LSTM -> (B,6,16) -> attention -> (B,16,2) -> fc -> (B,class_num)

    :param int class_num: number of output classes
    :param fs: not used by the model; kept for interface compatibility
    """

    def __init__(self, class_num, fs):
        super(MyLSTM, self).__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, dilation=1, stride=1, padding=1, bias=False),
            nn.Conv2d(64, 64, kernel_size=3, dilation=1, stride=1, padding=1, bias=False),
            nn.LayerNorm([64, 26, 92], elementwise_affine=False),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.LayerNorm([64, 13, 46], elementwise_affine=False),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, 64, kernel_size=3, dilation=1, stride=1, padding=1, bias=False),

            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.LayerNorm([64, 6, 23], elementwise_affine=False),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, 64, kernel_size=3, dilation=1, stride=1, padding=1, bias=False))

        # Max-pools each row's (64 channels x frames) map down to 8 x 8.
        self.avg = nn.Sequential(nn.AdaptiveMaxPool2d((8, 8)),
                                 nn.LayerNorm([6, 8, 8], elementwise_affine=False))

        # 64 input features per step, 8 hidden units per direction, 2 layers.
        self.lstm = nn.LSTM(64, 8, 2,
                            bias=False,
                            batch_first=True,
                            dropout=0.5,
                            bidirectional=True)

        # Attention context: (2 directions * 8 hidden) x 2 layers = 32 features.
        self.fc = nn.Sequential(nn.Linear(16 * 2, class_num))

    def attention_net(self, lstm_output, final_state):
        """Attend over the LSTM steps once per layer's final hidden state.

        :param Tensor lstm_output: ``(batch, n_step, num_directions * hidden_size)``
        :param Tensor final_state: ``h_n`` from ``nn.LSTM``,
            ``(num_layers * num_directions, batch, hidden_size)``
        :return: ``context`` of shape ``(batch, num_directions * hidden_size, num_layers)``
            and the softmax attention weights as a NumPy array of shape
            ``(batch, n_step, num_layers)``
        """
        num_layers = self.lstm.num_layers
        num_directions = 2 if self.lstm.bidirectional else 1
        hidden_size = self.lstm.hidden_size
        batch = final_state.shape[1]
        # h_n is laid out (layer, direction, batch, hidden).  For each sample and
        # layer, concatenate [forward; backward] -- the layout of lstm_output's
        # last axis.  The former `final_state.view(-1, 16, 2)` reinterpreted this
        # memory and mixed hidden states of *different* samples in the batch.
        hidden = (final_state.reshape(num_layers, num_directions, batch, hidden_size)
                  .permute(2, 1, 3, 0)
                  .reshape(batch, num_directions * hidden_size, num_layers))
        attn_weights = torch.bmm(lstm_output, hidden)        # (batch, n_step, num_layers)
        soft_attn_weights = F.softmax(attn_weights, 1)      # normalise over the steps
        # (batch, D*H, n_step) @ (batch, n_step, num_layers) -> (batch, D*H, num_layers)
        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights)
        return context, soft_attn_weights.detach().cpu().numpy()

    def forward(self, x):
        """Return ``(logits, attention_weights)`` for a ``(batch, 26, 92)`` input."""
        out = self.branch1(x.unsqueeze(dim=1))            # (B, 64, 6, 23)
        out = out.permute(0, 2, 1, 3)                     # (B, 6, 64, 23)
        out = self.avg(out)                               # (B, 6, 8, 8)
        out = out.view(out.shape[0], out.shape[1], -1)    # (B, 6, 64)
        out, (h, c) = self.lstm(out)
        out, attn = self.attention_net(out, h)
        out = self.fc(out.reshape(out.shape[0], -1))
        return out, attn

In [None]:
model = MyLSTM(2, 100)

# Re-draw conv weights from N(0, 0.1).  The guard (inherited from a template)
# skips only convs with in == out == groups and no bias, i.e. bias-free
# depthwise convs; MyLSTM's Conv2d layers all use groups=1, so each is re-drawn.
for module in model.modules():
    if not isinstance(module, nn.Conv2d):
        continue
    is_depthwise_no_bias = (module.in_channels == module.out_channels == module.groups
                            and module.bias is None)
    if is_depthwise_no_bias:
        print('Not initializing')
    else:
        nn.init.normal_(module.weight, mean=0, std=0.1)
        print(module, 'init')

criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.0003, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.001)
loss_list = []

In [None]:
def eval(model, test_dtlist, criterion):
    """Evaluate ``model`` on the segment files in ``test_dtlist`` and print metrics.

    Runs the model without gradients in batches of 256 (no noise augmentation),
    prints the batch loss every 8th batch, then prints the confusion matrix,
    classification report, accuracy, precision, recall and macro / micro /
    weighted F1, predicting class 1 when its softmax probability exceeds 0.5.

    NOTE: the name shadows the built-in ``eval``; kept for existing callers.
    NOTE(review): this function does not call ``model.eval()``; the caller is
    responsible for putting the model in evaluation mode.

    :return float: accuracy on ``test_dtlist``
    :raises ValueError: if ``test_dtlist`` is empty
    """
    if len(test_dtlist) == 0:
        raise ValueError("test_dtlist is empty")
    test_dataset = ApneaECGDataset(test_dtlist, istrain=False)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, sampler=None, num_workers=0)

    # Collect per-batch outputs and concatenate once (avoids repeated torch.cat).
    logits_batches = []
    label_batches = []
    with torch.no_grad():
        for batch_idx, (fe, label) in enumerate(test_loader, start=1):
            fe = fe.float()
            label = label.long()
            pred_prob_no_softmax, _ = model(fe)
            logits_batches.append(pred_prob_no_softmax.cpu())
            label_batches.append(label.cpu())
            if batch_idx % 8 == 0:
                loss = criterion(pred_prob_no_softmax, label)
                print('Eval Loss: ', loss.item())

    all_logits = torch.cat(logits_batches, 0)
    all_label = torch.cat(label_batches, 0).numpy()
    # Softmax over the class axis (explicit dim; nn.Softmax() without dim is deprecated).
    prob_pos = F.softmax(all_logits, dim=1)[:, 1].numpy()
    all_pred = (prob_pos > 0.5).astype(np.int64)

    acc = accuracy_score(all_label, all_pred)
    print(confusion_matrix(all_label, all_pred))
    print(classification_report(all_label, all_pred))
    print("acc: ", acc)
    print("pre: ", precision_score(all_label, all_pred))
    print("rec: ", recall_score(all_label, all_pred))
    print("ma F1: ", f1_score(all_label, all_pred, average='macro'))
    print("mi F1: ", f1_score(all_label, all_pred, average='micro'))
    print("we F1: ", f1_score(all_label, all_pred, average='weighted'))
    return acc

In [None]:
# Train MyLSTM.  Every 3rd epoch (starting at epoch 0, before any training)
# evaluate on the test and training lists, and checkpoint to 'param.pkl'
# whenever test accuracy beats the best seen so far (initial bar 0.8).
ma = 0.8
eval_every = 3
n_epochs = 26

# The training set is the same every epoch, so build the loader once;
# shuffle=True still draws a new order on each pass.
train_dataset = ApneaECGDataset(train_dtlist, istrain=False)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, sampler=None, num_workers=0)

opt.zero_grad()
for epoch in range(n_epochs):
    print('epoch: ', epoch)
    if epoch % eval_every == 0:
        model.eval()
        atmp = eval(model, test_dtlist, criterion)
        eval(model, train_dtlist, criterion)  # training-set metrics, printed only
        if atmp > ma:
            # Plain assignment: accuracy_score may return a Python float, which
            # has no .copy() (the old `atmp.copy()` raised AttributeError).
            ma = atmp
            torch.save(model.state_dict(), 'param.pkl')
            print(atmp, ma)
        model.train()

    for step, (fe, label) in enumerate(train_loader, start=1):
        fe = fe.float()
        label = label.long()
        pred_prob_no_softmax, _ = model(fe)
        loss = criterion(pred_prob_no_softmax, label)

        loss_list.append(loss.item())
        loss.backward()
        opt.step()
        opt.zero_grad()
        if step % 4 == 0:
            print("Loss: ", loss.item())
    print(epoch, '*' * 10)

plt.plot(loss_list)
plt.show()
plt.close('all')
print('------')

In [None]:
# Attention weights of the best checkpoint on every clean apnea ('A') segment.
# New names are used so the training split (train_dtlist / train_loader) is
# not overwritten.
model.load_state_dict(torch.load('param.pkl'))
apnea_dtlist = [path for path in dtlist
                if dtclean(path) == 0 and os.path.basename(path)[0] == 'A']
apnea_loader = DataLoader(ApneaECGDataset(apnea_dtlist, istrain=False),
                          batch_size=256, shuffle=False, sampler=None, num_workers=0)
model.eval()
# Per batch: column 0 and column 1 of the attention weights' last axis.
ftmp_ = []
btmp_ = []
with torch.no_grad():
    for fe, label in apnea_loader:
        _, attn = model(fe.float())
        ftmp_.append(attn[:, :, 0])
        btmp_.append(attn[:, :, 1])

In [None]:
# Distribution of attention weights: the same extraction for every clean
# normal ('N') segment.  New names keep the training split intact.
model.load_state_dict(torch.load('param.pkl'))
normal_dtlist = [path for path in dtlist
                 if dtclean(path) == 0 and os.path.basename(path)[0] == 'N']
normal_loader = DataLoader(ApneaECGDataset(normal_dtlist, istrain=False),
                           batch_size=256, shuffle=False, sampler=None, num_workers=0)
model.eval()
# Per batch: column 0 and column 1 of the attention weights' last axis.
fntmp_ = []
bntmp_ = []
with torch.no_grad():
    for fe, label in normal_loader:
        _, attn = model(fe.float())
        fntmp_.append(attn[:, :, 0])
        bntmp_.append(attn[:, :, 1])

# Stack the per-batch arrays row-wise into one matrix per column.
an = np.vstack(fntmp_)
bn = np.vstack(bntmp_)
an.shape

In [None]:
# Stack the per-batch apnea attention arrays row-wise into single matrices.
a = np.vstack(ftmp_)
b = np.vstack(btmp_)
a.shape

In [None]:
# Figure 4: joint density of the two attention-weight columns at each of the
# 6 attention steps -- apnea segments (green, panels 1-6) vs normal segments
# (blue, panels 7-12).
import seaborn as sns  # was used but never imported (NameError on a fresh kernel)

sns.set_theme(style='darkgrid', font='Times New Roman')
fig = plt.figure(dpi=256)
fig.subplots_adjust(hspace=0.09, wspace=0.05)

panel_groups = [((a, b), 'Greens', 0), ((an, bn), 'Blues', 6)]
for (first_col, second_col), cmap, offset in panel_groups:
    for i in range(6):
        plt.subplot(3, 4, offset + i + 1)
        # seaborn >= 0.11 API: x/y are keyword arguments, `shade` -> `fill`,
        # and `shade_lowest=False` -> `thresh` (0.05 is seaborn's default).
        sns.kdeplot(x=first_col[:, i], y=second_col[:, i], fill=True, cmap=cmap, thresh=0.05)
        plt.xlim(0.1, 0.25)
        plt.ylim(0.1, 0.25)
        plt.legend([i])
        plt.xticks([])
        plt.yticks([])
        if offset == 0 and i == 4:
            plt.ylabel('Backward weights')

plt.savefig('Atten.pdf', dpi=256)
plt.show()
plt.close("all")
plt.close("all")

In [None]:
# Multiply-accumulate (MAC) count and parameter count of the RRNet baseline (via thop)
import torch
from torch import nn

class ResBlock1(nn.Module):
    """Residual block whose output is one sample shorter than its input.

    Main path: conv(k=3, no padding) -> BN -> ReLU -> conv(k=3, padding=1),
    then one zero appended on the right.  Shortcut: AvgPool1d(k=2, stride=1)
    followed by a 1x1 conv.  For input length L both paths yield L - 1.
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.conv1 = nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=1, padding=0)
        self.bn = nn.BatchNorm1d(out_planes)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(out_planes, out_planes, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.Sequential(
            nn.AvgPool1d(kernel_size=2, stride=1, padding=0),
            nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        main = self.conv2(self.relu(self.bn(self.conv1(x))))
        main = torch.nn.functional.pad(main, (0, 1))  # L-2 -> L-1, matching the shortcut
        return main + self.downsample(x)

class ResBlock2(nn.Module):
    """Pre-activation residual block that preserves sequence length.

    Main path: BN -> ReLU -> conv(k=3, pad=1) -> BN -> ReLU -> conv(k=3, pad=1).
    Shortcut: 1x1 conv, applied even when ``in_planes == out_planes``.
    ``dropout1`` / ``dropout2`` are constructed but not used in ``forward``.
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.bn1 = nn.BatchNorm1d(in_planes)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout()
        self.conv1 = nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=1, padding=1)

        self.bn2 = nn.BatchNorm1d(out_planes)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout()
        self.conv2 = nn.Conv1d(out_planes, out_planes, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        for layer in (self.bn1, self.relu1, self.conv1, self.bn2, self.relu2, self.conv2):
            h = layer(h)
        return h + self.downsample(x)


class RRNet(nn.Module):
    """1-D residual CNN baseline (profiled below for MACs / parameters).

    ``End`` is an AvgPool1d with a fixed kernel of 67, matching the feature
    length for a ``(batch, 1, 360)`` input:
    conv(k=20, s=5): 360 -> 69, AvgPool1d(k=2, s=1): 69 -> 68, ResBlock1: 68 -> 67.
    The stage lists are plain Python lists; the modules are registered through
    ``Features``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.Begin = [nn.Conv1d(1, 128, kernel_size=20, stride=5, padding=0),
                      nn.BatchNorm1d(128),
                      nn.ReLU(),
                      nn.AvgPool1d(kernel_size=2, stride=1)]

        self.Res1 = [ResBlock1(128, 64)]
        self.Res64 = [ResBlock2(64, 64) for _ in range(2)]
        self.Res128 = [ResBlock2(64, 128)] + [ResBlock2(128, 128) for _ in range(3)]
        self.Res256 = [ResBlock2(128, 256)] + [ResBlock2(256, 256) for _ in range(5)]
        self.Res512 = [ResBlock2(256, 512)] + [ResBlock2(512, 512) for _ in range(2)]

        self.End = [nn.AvgPool1d(kernel_size=67, stride=1, padding=0)]
        stages = (self.Begin, self.Res1, self.Res64, self.Res128,
                  self.Res256, self.Res512, self.End)
        self.Features = nn.Sequential(*[layer for stage in stages for layer in stage])
        self.Out = nn.Sequential(nn.Linear(512, num_classes))

    def forward(self, x):
        feats = self.Features(x)
        logits = self.Out(feats.view(feats.size(0), 1, -1))
        return logits.view(logits.size(0), -1)
# Count MACs and parameters of the RRNet baseline for one 360-sample input.
# (To profile the proposed model instead: MyLSTM(2, 100) with a (1, 26, 92) input.)
from thop import profile
from thop import clever_format

# Separate name so the trained MyLSTM in `model` is not overwritten.
rrnet = RRNet(2)
rrnet.eval()  # avoid updating BatchNorm running stats while profiling
# Only the shape matters for profiling; torch.FloatTensor(1, 1, 360) would
# allocate uninitialised memory, so use zeros.
inp = torch.zeros(1, 1, 360)
with torch.no_grad():
    macs, params = profile(rrnet, inputs=(inp,))
print(macs, params)
macs, params = clever_format([macs, params], "%.3f")
print(macs, params)