# Single-Channel ResNeSt for SED architecture
---

This notebook contains code for modifying ResNeSt models for single-channel use. This permits their inclusion as the encoder in a modified PANNs SED architecture. This method was used as part of my submission to the RFCX audio classifier competition.


I am indebted to Ryan Epp for his [excellent notebook](https://www.kaggle.com/reppic/mean-teachers-find-more-birds) and to Hidehisa Arai for the [original PANNs SED notebook](https://www.kaggle.com/hidehisaarai1213/introduction-to-sound-event-detection) on Kaggle.

In [None]:
# torchlibrosa supplies the Spectrogram / LogmelFilterBank layers imported below.
# stdout is redirected to /dev/null to keep the cell output short.
!pip install torchlibrosa > /dev/null
# ResNeSt is installed straight from its GitHub repository.
!pip install git+https://github.com/zhanghang1989/ResNeSt.git > /dev/null

### Params

In [None]:
# --- Model selection ---------------------------------------------------------
# Key into RESIZE_DICT. NOTE: elsewhere in this file the encoder layout and the
# checkpoint URL are written for resnest50, so changing this value alone only
# changes RESIZE.
RESNEST_TYPE = '50'

# Size associated with each ResNeSt variant; the selected entry becomes RESIZE.
RESIZE_DICT = {
    '50': 224,
    '101': 256,
    '200': 320,
}

# --- Network dimensions ------------------------------------------------------
N_CLASSES = 24                       # default number of output classes
N_CHANNELS = 1                       # input channels of the encoder's first conv
RESIZE = RESIZE_DICT[RESNEST_TYPE]   # default number of mel bins
ENCODER_LEN = 2048                   # feature width fed to the FC / attention head
DROPOUT = 0.5                        # default dropout probability

# --- Audio front-end defaults ------------------------------------------------
FFT = 4096      # default n_fft
HOP = 512       # default hop_length
F_MIN = 60      # default fmin of the mel filter bank
F_MAX = 14000   # default fmax of the mel filter bank
SR = 36000      # default sample_rate

### Libraries

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from resnest.torch import resnest50, resnest101, resnest200
from resnest.torch.resnet import ResNet, Bottleneck

# Single-channel pretrained ResNeSt feature extractor
---

This model removes the average-pooling, flatten and fully-connected layers from the original ResNeSt model. This allows it to be used as a feature extractor, since the existing code does not have an inbuilt function for this. If you wish to use it as a full CNN model, uncomment the lines in `forward()`.

Additionally, the initial convolutional layer is modified to allow it to take single-channel image input.

All credit to the original author *zhanghang1989* at github: https://github.com/zhanghang1989/ResNeSt 

In [None]:
class ResNestEncoder(ResNet):
    """ResNeSt backbone that returns the feature map of the last residual stage.

    The constructor arguments ([3, 4, 6, 3] bottleneck blocks, radix 2,
    deep stem of width 32, ...) are fixed here; `get_model` loads the
    resnest50 checkpoint into this layout. The pooling / dropout / FC head
    created by the parent class is still constructed (so the checkpoint loads
    without key mismatches) but is skipped in `forward`.
    """

    def __init__(self):
        super().__init__(
            Bottleneck, [3, 4, 6, 3],
            radix=2, groups=1, bottleneck_width=64,
            deep_stem=True, stem_width=32,
            avg_down=True, avd=True, avd_first=False,
        )

    def forward(self, x):
        """Run the stem and the four residual stages; return the raw feature map."""
        # Stem: conv -> batch norm -> ReLU -> max-pool.
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Residual stages, applied in order.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        # To use the full classifier instead of the feature extractor,
        # re-enable the original head here:
        #     x = self.avgpool(x)
        #     x = torch.flatten(x, 1)
        #     if self.drop:
        #         x = self.drop(x)
        #     x = self.fc(x)

        return x
    

def get_model():
    """Build a pretrained ResNeSt-50 encoder that accepts N_CHANNELS input channels.

    The resnest50 checkpoint is downloaded (and cached by torch.hub) and loaded
    into a `ResNestEncoder`. Only the first convolution of the deep stem depends
    on the number of input channels, so only that layer is replaced; the
    remaining stem layers keep their pretrained weights instead of being
    re-initialised at random.

    The replacement convolution is seeded from the pretrained RGB kernels:
    they are summed over the input-channel axis and spread evenly over the
    N_CHANNELS new input channels, so a single-channel image produces the same
    stem activations as that image replicated across three channels.

    Returns:
      ResNestEncoder with pretrained weights and an N_CHANNELS-channel input.
    """
    model = ResNestEncoder()
    pretrained_state = torch.hub.load_state_dict_from_url(
        'https://s3.us-west-1.wasabisys.com/resnest/torch/resnest50-528c19ca.pth',
        progress=True, check_hash=True)
    model.load_state_dict(pretrained_state)

    # With deep_stem=True the stem is a Sequential whose first module is the
    # 3 -> 32 channel convolution; it is the only layer tied to the input depth.
    rgb_conv = model.conv1[0]
    adapted_conv = nn.Conv2d(N_CHANNELS, 32, kernel_size=(3, 3), stride=(2, 2),
                             padding=(1, 1), bias=False)  # <-- in_channels specified here

    with torch.no_grad():
        # (32, 3, k, k) -> (32, 1, k, k): collapse the RGB kernels into one.
        collapsed = rgb_conv.weight.sum(dim=1, keepdim=True)
        # Share the collapsed kernel across the new input channels; dividing
        # by N_CHANNELS keeps the overall response magnitude unchanged.
        adapted_conv.weight.copy_(collapsed.repeat(1, N_CHANNELS, 1, 1) / N_CHANNELS)

    model.conv1[0] = adapted_conv

    return model

# Build the encoder once at module level so the next cell can exercise it.
# NOTE: get_model() fetches the pretrained checkpoint via torch.hub.
resnest_feature_extractor = get_model()

We can test it by passing a single-channel image tensor:

In [None]:
# Smoke test: push random single-channel "images" through the encoder.
a = torch.rand([2, 1, 64, 64]) # a batch of 2 single-channel images of size 64px by 64px
# The call returns the feature map of the last residual stage (no pooling / FC head).
resnest_feature_extractor(a)

# PANNs SED Architecture using ResNeSt
---
The only pertinent change is that the size of `ENCODER_LEN` is different from that of other CNNs. The current value works for `resnest50` - you may have to adjust it for other variants.

In [None]:
def init_layer(layer):
    """Initialise a layer in place: Xavier-uniform weights and a zeroed bias.

    Args:
      layer: object with a `weight` tensor; a `bias` attribute is optional and
        may be None, in which case only the weight is touched.
    """
    nn.init.xavier_uniform_(layer.weight)

    # A missing attribute and an explicit None are treated the same way.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.)


def init_bn(bn):
    """Reset a batch-norm layer's affine parameters in place (bias 0, weight 1)."""
    for name, value in (("bias", 0.), ("weight", 1.0)):
        getattr(bn, name).data.fill_(value)

def interpolate(x: torch.Tensor, ratio: int):
    """Upsample along the time axis by repeating every time step `ratio` times.

    This compensates for the temporal resolution lost to downsampling in a CNN.

    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, number of copies made of each time step
    Returns:
      upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    # Unpacking into exactly three names rejects inputs that are not rank 3.
    n_batch, n_steps, n_classes = x.shape
    # (batch, steps, 1, classes) -> (batch, steps, ratio, classes)
    repeated = x.unsqueeze(2).repeat(1, 1, ratio, 1)
    # Merge the (steps, ratio) axes so the copies of a step sit next to each other.
    return repeated.flatten(1, 2)

class AttentionHead(nn.Module):
    """Attention-pooling head built from two parallel 1x1 convolutions.

    One convolution produces per-class attention scores, the other per-class
    outputs; the clip-level output is the attention-weighted sum of the
    per-class outputs over the last axis of the input.

    Args:
      in_features: number of input channels
      out_features: number of output channels (classes)
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Both branches use the same 1x1 convolution configuration.
        pointwise = dict(in_channels=in_features, out_channels=out_features,
                         kernel_size=1, stride=1, padding=0, bias=True)
        self.conv_attention = nn.Conv1d(**pointwise)
        self.conv_classes = nn.Conv1d(**pointwise)
        # Registered and initialised, but not applied in forward().
        self.batch_norm_attention = nn.BatchNorm1d(out_features)
        self.init_weights()

    def init_weights(self):
        """(Re-)initialise both convolutions and the batch-norm layer."""
        for conv in (self.conv_attention, self.conv_classes):
            init_layer(conv)
        init_bn(self.batch_norm_attention)

    def forward(self, x):
        """Return (pooled output, attention weights, per-position class outputs).

        Args:
          x: (batch_size, in_features, steps)
        Returns:
          pooled: (batch_size, out_features)
          norm_att: (batch_size, out_features, steps), sums to 1 over the last axis
          classes: (batch_size, out_features, steps)
        """
        # tanh bounds the scores before they are normalised over the last axis.
        attention_scores = torch.tanh(self.conv_attention(x))
        norm_att = torch.softmax(attention_scores, dim=-1)
        classes = self.conv_classes(x)
        pooled = (norm_att * classes).sum(dim=2)
        return pooled, norm_att, classes


class SEDAudioClassifier(nn.Module):
    """Sound-event-detection model: log-mel front end, ResNeSt encoder, attention head.

    Args:
      sample_rate: sample rate passed to the mel filter bank
      n_fft: FFT size used by the spectrogram and the mel filter bank
      hop_length: hop size of the spectrogram
      mel_bins: number of mel bins (also the width of the input batch norm)
      fmin: lowest frequency of the mel filter bank
      fmax: highest frequency of the mel filter bank
      n_classes: number of output classes
      dropout: dropout probability applied before and after the FC layer
    """

    def __init__(self, sample_rate=SR, n_fft=FFT, hop_length=HOP,
                 mel_bins=RESIZE, fmin=F_MIN, fmax=F_MAX,
                 n_classes=N_CLASSES, dropout=DROPOUT):
        super().__init__()
        # Number of times each segment-wise prediction is repeated in forward();
        # presumably the encoder's overall temporal downsampling factor - TODO confirm.
        self.interpolate_ratio = 32

        # Waveform -> spectrogram -> log-mel, both with frozen parameters.
        self.spectrogram_extractor = Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=None,
            window='hann',
            center=True,
            pad_mode='reflect',
            freeze_parameters=True,
        )
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=n_fft,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=1.0,
            amin=1e-10,
            top_db=None,
            freeze_parameters=True,
        )

        self.batch_norm = nn.BatchNorm2d(mel_bins)
        self.encoder = get_model()
        self.fc = nn.Linear(ENCODER_LEN, ENCODER_LEN, bias=True)
        self.att_head = AttentionHead(ENCODER_LEN, n_classes)
        # Not used in forward(); kept as a registered submodule.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(dropout)
        self.init_weight()

    def init_weight(self):
        """Initialise the input batch norm, the FC layer and the attention head."""
        init_bn(self.batch_norm)
        init_layer(self.fc)
        self.att_head.init_weights()

    def forward(self, input, spec_aug=False,
                mixup_lambda=None, return_encoding=False):
        """Compute clip-level and frame-level outputs for a batch of waveforms.

        Args:
          input: batch of waveforms; cast to float before feature extraction
            (the cell below passes shape (batch_size, samples))
          spec_aug, mixup_lambda, return_encoding: accepted but not used here
        Returns:
          clipwise_output: attention-pooled output of the head
          framewise_output: segment-wise output repeated `interpolate_ratio`
            times along the time axis
        """
        spectrogram = self.spectrogram_extractor(input.float())
        logmel = self.logmel_extractor(spectrogram)

        # Batch norm works on dim 1, so move the mel axis (dim 3) there and back.
        normalised = self.batch_norm(logmel.transpose(1, 3)).transpose(1, 3)

        # Encode, then average away the last axis to leave (batch, channels, steps).
        features = self.encoder(normalised).mean(dim=3)

        # Sum of max- and average-pooling with a length-preserving window of 3.
        pooled = (F.max_pool1d(features, kernel_size=3, stride=1, padding=1)
                  + F.avg_pool1d(features, kernel_size=3, stride=1, padding=1))

        # The FC layer acts on the channel axis, which must be last for nn.Linear.
        hidden = self.dropout(pooled).transpose(1, 2)
        hidden = F.relu_(self.fc(hidden)).transpose(1, 2)
        hidden = self.dropout(hidden)

        clipwise_output, _norm_att, segmentwise_output = self.att_head(hidden)
        framewise_output = interpolate(segmentwise_output.transpose(1, 2),
                                       self.interpolate_ratio)

        return clipwise_output, framewise_output
        

We can test this by passing it an example tensor that is the same shape as a batch of audio vectors. This is the input that PANNs SED expects:

In [None]:
# Smoke test: build the full classifier (this constructs its own encoder via get_model()).
model = SEDAudioClassifier()
a = torch.rand([2, 25600]) # a batch of 2 flat audio tensors
# The call returns the tuple (clipwise_output, framewise_output).
model(a)

Success!