In [1]:
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import os
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
device = "cuda"

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1292e613af0>

In [2]:
class Physionet(Dataset):
    def __init__(self, root, train="train"):
        self.audios = torch.tensor(np.load(os.path.join(root, train + "_audios.npy"))).cuda()
        self.labels = torch.tensor(np.load(os.path.join(root, train + "_labels.npy")))
        self.labels = self.labels[:, ].cuda()
        # self.labels = torch.argmax(self.labels, dim=1)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return self.audios[index], self.labels[index]

In [3]:
data_path = "..\\tydata\\audio"

train_dataset = Physionet(data_path, train="train")
val_dataset = Physionet(data_path, train="val")
test_dataset = Physionet(data_path, train="test")

print(f"Train: {len(train_dataset)}")
print(f"Val: {len(val_dataset)}")
print(f"Test: {len(test_dataset)}")

Train: 70235
Val: 8779
Test: 8780


In [4]:
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
import math

# 首先，在函数get_freq_indices(method)中，根据给定的method参数，确定了需要选择的频率索引。method参数可以是以下几种取值：
# 'top1','top2','top4','top8','top16','top32','bot1','bot2','bot4','bot8','bot16','bot32','low1','low2','low4','low8','low16','low32'。
# 根据不同的取值，函数会返回相应数量的频率索引列表（mapper_x和mapper_y）。
def get_freq_indices_l(method):
    assert method in ['top1', 'top2', 'top4', 'top8', 'top16', 'top32',
                      'bot1', 'bot2', 'bot4', 'bot8', 'bot16', 'bot32',
                      'low1', 'low2', 'low4', 'low8', 'low16', 'low32']
    num_freq = int(method[3:])  # 当method为16时，num_freq=16
    if 'top' in method:
        all_top_indices_x = [0, 0, 2, 6, 0, 0, 0, 0, 4, 6, 6, 6, 1, 5, 6, 5, 3, 3, 5, 6, 2, 0, 5, 1, 4, 3, 6, 4, 5, 4,
                             4, 3]
        all_top_indices_y = [1, 6, 3, 0, 4, 2, 0, 5, 0, 6, 1, 2, 2, 0, 3, 1, 3, 0, 2, 4, 0, 3, 5, 5, 4, 6, 5, 2, 6, 1,
                             5, 4]
        mapper_x = all_top_indices_x[:num_freq]
        mapper_y = all_top_indices_y[:num_freq]
    elif 'low' in method:
        all_low_indices_x = [0, 0, 1, 1, 0, 2, 2, 1, 2, 0, 3, 4, 0, 1, 3, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 1, 2,
                             3, 4]
        all_low_indices_y = [0, 1, 0, 1, 2, 0, 1, 2, 2, 3, 0, 0, 4, 3, 1, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5,
                             4, 3]
        mapper_x = all_low_indices_x[:num_freq]
        mapper_y = all_low_indices_y[:num_freq]
    elif 'bot' in method:
        all_bot_indices_x = [6, 1, 3, 3, 2, 4, 1, 2, 4, 4, 5, 1, 4, 6, 2, 5, 6, 1, 6, 2, 2, 4, 3, 3, 5, 5, 6, 2, 5, 5,
                             3, 6]
        all_bot_indices_y = [6, 4, 4, 6, 6, 3, 1, 4, 4, 5, 6, 5, 2, 2, 5, 1, 4, 3, 5, 0, 3, 1, 1, 2, 4, 2, 1, 1, 5, 3,
                             3, 3]
        mapper_x = all_bot_indices_x[:num_freq]
        mapper_y = all_bot_indices_y[:num_freq]
    else:
        raise NotImplementedError

    return mapper_x, mapper_y


def get_freq_indices_r(method):
    assert method in ['top1', 'top2', 'top4', 'top8', 'top16', 'top32',
                      'bot1', 'bot2', 'bot4', 'bot8', 'bot16', 'bot32',
                      'low1', 'low2', 'low4', 'low8', 'low16', 'low32']
    num_freq = int(method[3:])  # 当method为16时，num_freq=16
    if 'top' in method:
        all_top_indices_x = [0, 2, 6, 0, 0, 0, 3, 2, 5, 2, 1, 1, 0, 0, 3, 1, 3, 4, 5, 6, 2, 3, 6, 5, 6, 1, 6, 4, 4, 4,
                             4, 1]
        all_top_indices_y = [3, 0, 6, 4, 1, 0, 1, 1, 1, 6, 5, 0, 6, 2, 3, 3, 0, 2, 0, 0, 2, 2, 3, 2, 1, 4, 4, 0, 4, 5,
                             6, 6]
        mapper_x = all_top_indices_x[:num_freq]
        mapper_y = all_top_indices_y[:num_freq]
    elif 'low' in method:
        all_low_indices_x = [0, 0, 1, 1, 0, 2, 2, 1, 2, 0, 3, 4, 0, 1, 3, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 1, 2,
                             3, 4]
        all_low_indices_y = [0, 1, 0, 1, 2, 0, 1, 2, 2, 3, 0, 0, 4, 3, 1, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5,
                             4, 3]
        mapper_x = all_low_indices_x[:num_freq]
        mapper_y = all_low_indices_y[:num_freq]
    elif 'bot' in method:
        all_bot_indices_x = [6, 1, 3, 3, 2, 4, 1, 2, 4, 4, 5, 1, 4, 6, 2, 5, 6, 1, 6, 2, 2, 4, 3, 3, 5, 5, 6, 2, 5, 5,
                             3, 6]
        all_bot_indices_y = [6, 4, 4, 6, 6, 3, 1, 4, 4, 5, 6, 5, 2, 2, 5, 1, 4, 3, 5, 0, 3, 1, 1, 2, 4, 2, 1, 1, 5, 3,
                             3, 3]
        mapper_x = all_bot_indices_x[:num_freq]
        mapper_y = all_bot_indices_y[:num_freq]
    else:
        raise NotImplementedError

    return mapper_x, mapper_y


# 多个频率+通道+空间注意力机制
class MultiSpectralChannelSpatialAttentionLayer(torch.nn.Module):
    def __init__(self, channel, dct_h, dct_w, reduction=16, freq_sel_method='top8', side='l'):
        super(MultiSpectralChannelSpatialAttentionLayer, self).__init__()
        self.reduction = reduction
        self.dct_h = dct_h
        self.dct_w = dct_w

        if side == 'l':
            mapper_x, mapper_y = get_freq_indices_l(freq_sel_method)
        else:
            mapper_x, mapper_y = get_freq_indices_r(freq_sel_method)
        self.num_split = len(mapper_x)
        mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x]
        mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y]
        # make the frequencies in different sizes are identical to a 7x7 frequency space
        # eg, (2,2) in 14x14 is identical to (1,1) in 7x7

        # frequency
        self.dct_layer = MultiSpectralDCTLayer(dct_h, dct_w, mapper_x, mapper_y, channel)
        # channel
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )
        # spatial
        self.SpatialAttention = SpatialAttention()

    def forward(self, x):
        n, c, h, w = x.shape
        x_pooled = x
        if h != self.dct_h or w != self.dct_w:
            x_pooled = torch.nn.functional.adaptive_avg_pool2d(x, (self.dct_h, self.dct_w))
            # If you have concerns about one-line-change, don't worry.   :)
            # In the ImageNet models, this line will never be triggered. 
            # This is for compatibility in instance segmentation and object detection.
        y = self.dct_layer(x_pooled)

        y = self.fc(y).view(n, c, 1, 1)
        y = x * y.expand_as(x)

        y = self.SpatialAttention(y)

        return y


class MultiSpectralDCTLayer(nn.Module):
    """
    Generate dct filters
    """

    def __init__(self, height, width, mapper_x, mapper_y, channel):
        super(MultiSpectralDCTLayer, self).__init__()

        assert len(mapper_x) == len(mapper_y)
        # print(f"x {channel}  weight {len(mapper_x)}")
        # assert channel % len(mapper_x) == 0

        self.num_freq = len(mapper_x)

        # fixed DCT init
        self.register_buffer('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))

        # fixed random init
        # self.register_buffer('weight', torch.rand(channel, height, width))

        # learnable DCT init
        # self.register_parameter('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))

        # learnable random init
        # self.register_parameter('weight', torch.rand(channel, height, width))

        # num_freq, h, w

    def forward(self, x):
        assert len(x.shape) == 4, 'x must been 4 dimensions, but got ' + str(len(x.shape))
        # n, c, h, w = x.shape
        # print(f"x :{x.shape}  weight: {self.weight.shape}")
        x = x * self.weight

        result = torch.sum(x, dim=[2, 3])
        return result

    def build_filter(self, pos, freq, POS):
        result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS)
        if freq == 0:
            return result
        else:
            return result * math.sqrt(2)

    def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):
        dct_filter = torch.zeros(channel, tile_size_x, tile_size_y)

        c_part = channel // len(mapper_x)

        for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):
            for t_x in range(tile_size_x):
                for t_y in range(tile_size_y):
                    dct_filter[i * c_part: (i + 1) * c_part, t_x, t_y] = self.build_filter(t_x, u_x,
                                                                                           tile_size_x) * self.build_filter(
                        t_y, v_y, tile_size_y)

        return dct_filter


# 通道注意力
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction_ratio, in_channels)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * self.sigmoid(y)


# 空间注意力
class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_pool = torch.max(x, dim=1, keepdim=True)[0]
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        y = torch.cat([max_pool, avg_pool], dim=1)
        y = self.conv(y)
        return x * self.sigmoid(y)


# 通道+空间注意力 CBAM
class CBAM(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x

In [6]:
import librosa

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.relu2(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.relu3 = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = self.relu3(out)
        return out


class ZHJNet(nn.Module):
    def __init__(self):
        super().__init__()

        global _mapper_x, _mapper_y  #ty
        c2wh_l = [(36, 96), (18, 48), (18, 48), (9, 24), (9, 24), (9, 24)]  #ty big left
        c2wh_r = [(36, 96), (36, 96), (18, 48), (18, 48), (18, 48), (9, 24)]  #ty small right
        c_l = [16, 64, 128, 256, 512, 512]
        c_r = [16, 16, 32, 64, 128, 512]

        self.in_channels = 16

        self.filterbank_l = torch.tensor(librosa.filters.mel(
            sr=1000,
            n_fft=1024,
            n_mels=40,
            fmin=0.0,
            fmax=None,
            htk=False,
            norm='slaney',
        ).T).cuda()
        self.mask_l = (self.filterbank_l == 0).float().unsqueeze(0).cuda()
        self.inv_mask_l = 1 - self.mask_l
        self.filterbank_non_trainable_l = self.filterbank_l.clone().detach()
        self.filterbank_non_trainable_l.requires_grad = False
        self.filterbank_l = nn.Parameter(self.filterbank_l)
        print("left",self.filterbank_l.device, self.mask_l.device)
        
        self.filterbank_r = torch.tensor(librosa.filters.mel(
            sr=1000,
            n_fft=1024,
            n_mels=40,
            fmin=0.0,
            fmax=None,
            htk=False,
            norm='slaney',
        ).T).cuda()
        self.mask_r = (self.filterbank_r == 0).float().unsqueeze(0).cuda()
        self.inv_mask_r = 1 - self.mask_r
        self.filterbank_non_trainable_r = self.filterbank_r.clone().detach()
        self.filterbank_non_trainable_r.requires_grad = False
        self.filterbank_r = nn.Parameter(self.filterbank_r)
        print("right",self.filterbank_r.device, self.mask_r.device)
        
        self.conv_l_0 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=7, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )

        attention_num = 0
        self.conv_l_1 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            BasicBlock(16, 16, 1),
            BasicBlock(16, 16, 1),
            MultiSpectralChannelSpatialAttentionLayer(c_l[attention_num], c2wh_l[attention_num][0],
                                                      c2wh_l[attention_num][1], side='l')  #ty
        )

        attention_num = 1
        self.conv_l_2 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            Bottleneck(16, 16, 2),
            MultiSpectralChannelSpatialAttentionLayer(c_l[attention_num], c2wh_l[attention_num][0],
                                                      c2wh_l[attention_num][1], side='l')  #ty
        )

        attention_num = 2
        self.conv_l_3 = nn.Sequential(
            Bottleneck(64, 32, 1),
            Bottleneck(128, 32, 1),
            MultiSpectralChannelSpatialAttentionLayer(c_l[attention_num], c2wh_l[attention_num][0],
                                                      c2wh_l[attention_num][1], side='l')  #ty
        )

        attention_num = 3
        self.conv_l_4 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            Bottleneck(128, 64, 2),
            MultiSpectralChannelSpatialAttentionLayer(c_l[attention_num], c2wh_l[attention_num][0],
                                                      c2wh_l[attention_num][1], side='l')  #ty
        )

        attention_num = 4
        self.conv_l_5 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            Bottleneck(256, 64, 1),
            nn.Conv2d(256, 256, kernel_size=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            Bottleneck(256, 128, 1),
            Bottleneck(512, 128, 1),
            MultiSpectralChannelSpatialAttentionLayer(c_l[attention_num], c2wh_l[attention_num][0],
                                                      c2wh_l[attention_num][1], side='l')  #ty
        )

        attention_num = 5
        self.conv_l_6 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            Bottleneck(512, 128, 1),
            Bottleneck(512, 128, 1),
            MultiSpectralChannelSpatialAttentionLayer(c_l[attention_num], c2wh_l[attention_num][0],
                                                      c2wh_l[attention_num][1], side='l')  #ty
        )

        self.conv_l_12 = nn.Conv2d(16, 64, kernel_size=1)
        self.conv_l_12_3x3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv_l_23 = nn.Conv2d(64, 128, kernel_size=1)
        self.conv_l_23_1x1 = nn.Conv2d(128, 128, kernel_size=1)
        self.conv_l_34 = nn.Conv2d(128, 256, kernel_size=1)
        self.conv_l_34_3x3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv_l_45 = nn.Conv2d(256, 512, kernel_size=1)
        self.conv_l_45_1x1 = nn.Conv2d(512, 512, kernel_size=1)
        self.conv_l_56_1x1 = nn.Conv2d(512, 512, kernel_size=1)

        self.bn_l_12 = nn.BatchNorm2d(64)
        self.bn_l_23 = nn.BatchNorm2d(128)
        self.bn_l_34 = nn.BatchNorm2d(256)
        self.bn_l_45 = nn.BatchNorm2d(512)
        self.bn_l_56 = nn.BatchNorm2d(512)

        self.conv_r_0 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=7, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )

        attention_num = 0
        self.conv_r_1 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            BasicBlock(16, 16, 1),
            MultiSpectralChannelSpatialAttentionLayer(c_r[attention_num], c2wh_r[attention_num][0],
                                                      c2wh_r[attention_num][1], side='r')  #ty
        )

        attention_num = 1
        self.conv_r_2 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            BasicBlock(16, 16, 1),
            MultiSpectralChannelSpatialAttentionLayer(c_r[attention_num], c2wh_r[attention_num][0],
                                                      c2wh_r[attention_num][1], side='r')  #ty
        )

        attention_num = 2
        self.conv_r_3 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            BasicBlock(16, 16, 1),
            Bottleneck(16, 8, 2),
            MultiSpectralChannelSpatialAttentionLayer(c_r[attention_num], c2wh_r[attention_num][0],
                                                      c2wh_r[attention_num][1], side='r')  #ty
        )

        attention_num = 3
        self.conv_r_4 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            BasicBlock(32, 64, 1),
            MultiSpectralChannelSpatialAttentionLayer(c_r[attention_num], c2wh_r[attention_num][0],
                                                      c2wh_r[attention_num][1], side='r')  #ty
        )

        attention_num = 4
        self.conv_r_5 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            Bottleneck(64, 16, 1),
            Bottleneck(64, 16, 1),
            Bottleneck(64, 32, 1),
            MultiSpectralChannelSpatialAttentionLayer(c_r[attention_num], c2wh_r[attention_num][0],
                                                      c2wh_r[attention_num][1], side='r'),  #ty
        )

        attention_num = 5
        self.conv_r_6 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            Bottleneck(128, 128, 2),
            MultiSpectralChannelSpatialAttentionLayer(c_r[attention_num], c2wh_r[attention_num][0],
                                                      c2wh_r[attention_num][1], side='r')  #ty
        )

        self.conv_r_12_1x1 = nn.Conv2d(16, 16, kernel_size=1)
        self.conv_r_23 = nn.Conv2d(16, 32, kernel_size=1)
        self.conv_r_23_3x3 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv_r_34 = nn.Conv2d(32, 64, kernel_size=1)
        self.conv_r_34_1x1 = nn.Conv2d(64, 64, kernel_size=1)
        self.conv_r_45 = nn.Conv2d(64, 128, kernel_size=1)
        self.conv_r_45_1x1 = nn.Conv2d(128, 128, kernel_size=1)
        self.conv_r_56 = nn.Conv2d(128, 512, kernel_size=1)
        self.conv_r_56_3x3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        self.bn_r_12 = nn.BatchNorm2d(16)
        self.bn_r_23 = nn.BatchNorm2d(32)
        self.bn_r_34 = nn.BatchNorm2d(64)
        self.bn_r_45 = nn.BatchNorm2d(128)
        self.bn_r_56 = nn.BatchNorm2d(512)

        self.conv_m_12 = nn.Conv2d(16, 64, kernel_size=1)
        self.conv_m_12_1x1 = nn.Conv2d(64, 64, kernel_size=1)
        self.conv_m_23 = nn.Conv2d(32, 128, kernel_size=1)
        self.conv_m_23_3x3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv_m_34 = nn.Conv2d(64, 256, kernel_size=1)
        self.conv_m_34_1x1 = nn.Conv2d(256, 256, kernel_size=1)
        self.conv_m_45 = nn.Conv2d(128, 512, kernel_size=1)
        self.conv_m_45_3x3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv_m_56_3x3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        self.bn_m_12 = nn.BatchNorm2d(64)
        self.bn_m_23 = nn.BatchNorm2d(128)
        self.bn_m_34 = nn.BatchNorm2d(256)
        self.bn_m_45 = nn.BatchNorm2d(512)
        self.bn_m_56 = nn.BatchNorm2d(512)

        self.conv_f_12 = nn.Conv2d(64, 128, kernel_size=1)
        self.conv_f_23 = nn.Conv2d(128, 256, kernel_size=1, stride=2)
        self.conv_f_34 = nn.Conv2d(256, 512, kernel_size=1)
        self.bn_f_123 = nn.BatchNorm2d(128)
        self.bn_f_1234 = nn.BatchNorm2d(256)
        self.bn_f_12345 = nn.BatchNorm2d(512)
        self.bn_f_123456 = nn.BatchNorm2d(512)
        self.conv_f_123_1x1 = nn.Conv2d(128, 128, kernel_size=1)
        self.conv_f_1234_1x1 = nn.Conv2d(256, 256, kernel_size=1)
        self.conv_f_12345_1x1 = nn.Conv2d(512, 512, kernel_size=1)
        self.conv_f_123456_1x1 = nn.Conv2d(512, 512, kernel_size=1)

        self.decision_layer = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(0.1),
            nn.Linear(512, 2),
        )

    def _make_layer(self, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(Bottleneck(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * Bottleneck.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        # left
        waveform_to_stft = torch.stft(x, n_fft=1024, win_length=50, hop_length=25,
                                      window=torch.hann_window(50).cuda(), return_complex=True)
        stft_to_stftm = torch.transpose(torch.abs(waveform_to_stft), 1, 2).cuda()
        non_trainable_part = self.filterbank_non_trainable_l * self.mask_l
        trainable_part = self.filterbank_l * self.inv_mask_l
        stftm_to_melgram = torch.matmul(stft_to_stftm, trainable_part + non_trainable_part)
        x_l = stftm_to_melgram[:,:100,:].unsqueeze(1)

        x_l_0 = self.conv_l_0(x_l)
        x_l_1 = self.conv_l_1(x_l_0)
        x_l_2 = self.conv_l_2(x_l_1)
        x_l_3 = self.conv_l_3(x_l_2)
        x_l_4 = self.conv_l_4(x_l_3)
        x_l_5 = self.conv_l_5(x_l_4)
        x_l_6 = self.conv_l_6(x_l_5)

        x_l_12 = F.relu(self.bn_l_12(
            self.conv_l_12_3x3(self.conv_l_12(x_l_1) + F.interpolate(x_l_2, scale_factor=2, mode='nearest'))),
            inplace=True)
        x_l_23 = F.relu(self.bn_l_23(self.conv_l_23_1x1(self.conv_l_23(x_l_2) + x_l_3)), inplace=True)
        x_l_34 = F.relu(self.bn_l_34(
            self.conv_l_34_3x3(self.conv_l_34(x_l_3) + F.interpolate(x_l_4, scale_factor=2, mode='nearest'))),
            inplace=True)
        x_l_45 = F.relu(self.bn_l_45(self.conv_l_45_1x1(self.conv_l_45(x_l_4) + x_l_5)), inplace=True)
        x_l_56 = F.relu(self.bn_l_56(self.conv_l_56_1x1(x_l_5 + x_l_6)), inplace=True)
        
        # right
        waveform_to_stft = torch.stft(x, n_fft=1024, win_length=50, hop_length=25,
                                      window=torch.hann_window(50).cuda(), return_complex=True)
        stft_to_stftm = torch.transpose(torch.abs(waveform_to_stft), 1, 2).cuda()
        non_trainable_part = self.filterbank_non_trainable_r * self.mask_r
        trainable_part = self.filterbank_r * self.inv_mask_r
        stftm_to_melgram = torch.matmul(stft_to_stftm, trainable_part + non_trainable_part)
        x_r = stftm_to_melgram[:,:100,:].unsqueeze(1)

        x_r_0 = self.conv_r_0(x_r)
        x_r_1 = self.conv_r_1(x_r_0)
        x_r_2 = self.conv_r_2(x_r_1)
        x_r_3 = self.conv_r_3(x_r_2)
        x_r_4 = self.conv_r_4(x_r_3)
        x_r_5 = self.conv_r_5(x_r_4)
        x_r_6 = self.conv_r_6(x_r_5)

        x_r_12 = F.relu(self.bn_r_12(self.conv_r_12_1x1(x_r_1 + x_r_2)), inplace=True)
        x_r_23 = F.relu(self.bn_r_23(
            self.conv_r_23_3x3(self.conv_r_23(x_r_2) + F.interpolate(x_r_3, scale_factor=2, mode='nearest'))),
            inplace=True)
        x_r_34 = F.relu(self.bn_r_34(self.conv_r_34_1x1(self.conv_r_34(x_r_3) + x_r_4)), inplace=True)
        x_r_45 = F.relu(self.bn_r_45(self.conv_r_45_1x1(self.conv_r_45(x_r_4) + x_r_5)), inplace=True)
        x_r_56 = F.relu(self.bn_r_56(
            self.conv_r_56_3x3(self.conv_r_56(x_r_5) + F.interpolate(x_r_6, scale_factor=2, mode='nearest'))),
            inplace=True)
        
        # fusion
        x_m_12 = F.relu(self.bn_m_12(self.conv_m_12_1x1(x_l_12 + self.conv_m_12(x_r_12))), inplace=True)
        x_m_23 = F.relu(self.bn_m_23(
            self.conv_m_23_3x3(F.interpolate(x_l_23, scale_factor=2, mode='nearest') + self.conv_m_23(x_r_23))),
            inplace=True)
        x_m_34 = F.relu(self.bn_m_34(self.conv_m_34_1x1(x_l_34 + self.conv_m_34(x_r_34))), inplace=True)
        x_m_45 = F.relu(self.bn_m_45(
            self.conv_m_45_3x3(F.interpolate(x_l_45, scale_factor=2, mode='nearest') + self.conv_m_45(x_r_45))),
            inplace=True)
        x_m_56 = F.relu(
            self.bn_m_56(self.conv_m_56_3x3(F.interpolate(x_l_56, scale_factor=2, mode='nearest') + x_r_56)),
            inplace=True)

        x_f_123 = F.relu(self.bn_f_123(self.conv_f_123_1x1(self.conv_f_12(x_m_12) + x_m_23)), inplace=True)
        x_f_1234 = F.relu(self.bn_f_1234(self.conv_f_1234_1x1(self.conv_f_23(x_f_123) + x_m_34)), inplace=True)
        x_f_12345 = F.relu(self.bn_f_12345(self.conv_f_12345_1x1(self.conv_f_34(x_f_1234) + x_m_45)), inplace=True)
        x_f_123456 = F.relu(self.bn_f_123456(self.conv_f_123456_1x1((x_f_12345 + x_m_56))), inplace=True)
        x= self.decision_layer(x_f_123456)

        return x
    

In [7]:
from torchinfo import summary

summary(ZHJNet().cuda(), input_size=(1, 2540))

left cuda:0 cuda:0
right cuda:0 cuda:0


Layer (type:depth-idx)                                       Output Shape              Param #
ZHJNet                                                       [1, 2]                    41,040
├─Sequential: 1-1                                            [1, 16, 96, 36]           --
│    └─Conv2d: 2-1                                           [1, 16, 96, 36]           784
│    └─BatchNorm2d: 2-2                                      [1, 16, 96, 36]           32
│    └─ReLU: 2-3                                             [1, 16, 96, 36]           --
├─Sequential: 1-2                                            [1, 16, 96, 36]           --
│    └─Conv2d: 2-4                                           [1, 16, 96, 36]           256
│    └─BatchNorm2d: 2-5                                      [1, 16, 96, 36]           32
│    └─ReLU: 2-6                                             [1, 16, 96, 36]           --
│    └─BasicBlock: 2-7                                       [1, 16, 96, 36]           --

In [8]:
model = ZHJNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

left cuda:0 cuda:0
right cuda:0 cuda:0


In [11]:
from tqdm import tqdm

num_epochs = 200
start_epoch = 0
Min_Loss = 1000000
MAX_UAR = 0
logs = []
total_train = len(train_dataset)

for epoch in range(0, num_epochs):
    model.train()
    train_loss = 0.0
    for inputs, labels in tqdm(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * inputs.size(0)

    model.eval()
    correct_train = 0
    correct_test = 0
    total_test = 0

    TP = 0
    FN = 0
    TN = 0
    FP = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            _, labels_max = torch.max(labels, 1)
            total_test += labels.size(0)
            correct_test += (predicted == labels_max).sum().item()

            TN += ((predicted == 1) & (labels_max == 1)).sum().item()
            FN += ((predicted == 1) & (labels_max == 0)).sum().item()
            TP += ((predicted == 0) & (labels_max == 0)).sum().item()
            FP += ((predicted == 0) & (labels_max == 1)).sum().item()

    Se = TP / (TP + FN) if (TP + FN) > 0 else 0
    Sp = TN / (TN + FP) if (TN + FP) > 0 else 0
    Pr = TP / (TP + FP) if (TP + FP) > 0 else 0
    Acc = (TP + TN) / (TP + FP + TN + FN) if (TP + FP + TN + FN) > 0 else 0
    UAR = (Se + Sp) / 2
    F1 = (2 * Pr * Se) / (Pr + Se) if (Pr + Se) > 0 else 0
    logs.append([epoch, train_loss / total_train, Se, Sp, Pr, F1, Acc, UAR])
    print(f"Epoch [{epoch + 1}/{num_epochs}], "
          f"Train Loss: {(train_loss / total_train):.7f}, "
          f"TRAIN-Acc: {(correct_train / total_train) * 100:.4f}% "
          f"VAL--Acc: {(correct_test / total_test) * 100:.4f}% "
          f"Se:{Se * 100:.4f} "
          f"Sp:{Sp * 100:.4f} "
          f"Pr:{Pr * 100:.4f} "
          f"F1: {F1 * 100:.4f}% "
          f"Acc: {Acc * 100:.4f}% "
          f"UAR: {UAR * 100:.4f}% ")
    torch.save({
            'model_state_dict': model.state_dict(),
        }, f'results/TwoMouses/model_checkpoint_TwoMouses_{epoch}.pth')

print('Training Finished')

import csv

# 打开文件并将数组写入CSV文件
with open('results/TwoMouses/model_TwoMouses.csv', mode='w',
          newline='') as file:
    writer = csv.writer(file)
    writer.writerows(logs)

  0%|          | 1/1098 [00:03<55:02,  3.01s/it]


KeyboardInterrupt: 

In [8]:
import librosa
model = ZHJNet().cuda()
checkpoint = torch.load(
    'model_checkpoint_TwoMouses_141.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

correct_test = 0
total_test = 0
TP = 0
FN = 0
TN = 0
FP = 0
ori_labels = []
pred_labels = []
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        _, labels_max = torch.max(labels, 1)
        pred_labels.append(predicted)
        ori_labels.append(labels_max)
        total_test += labels.size(0)
        correct_test += (predicted == labels_max).sum().item()

        TN += ((predicted == 1) & (labels_max == 1)).sum().item()
        FN += ((predicted == 1) & (labels_max == 0)).sum().item()
        TP += ((predicted == 0) & (labels_max == 0)).sum().item()
        FP += ((predicted == 0) & (labels_max == 1)).sum().item()

Se = TP / (TP + FN) if (TP + FN) > 0 else 0
Sp = TN / (TN + FP) if (TN + FP) > 0 else 0
Pr = TP / (TP + FP) if (TP + FP) > 0 else 0
Acc = (TP + TN) / (TP + FP + TN + FN) if (TP + FP + TN + FN) > 0 else 0
UAR = (Se + Sp) / 2
F1 = (2 * Pr * Se) / (Pr + Se) if (Pr + Se) > 0 else 0

print(f"TEST-Acc: {(correct_test / total_test) * 100:.4f}% "
      f"Se:{Se * 100:.4f} "
      f"Sp:{Sp * 100:.4f} "
      f"Pr:{Pr * 100:.4f} "
      f"F1: {F1 * 100:.4f}% "
      f"Acc: {Acc * 100:.4f}% "
      f"UAR: {UAR * 100:.4f}% ")

left cuda:0 cuda:0
right cuda:0 cuda:0
TEST-Acc: 96.2187% Se:92.7072 Sp:97.1923 Pr:90.1531 F1: 91.4123% Acc: 96.2187% UAR: 94.9498% 
