# Network structure

The goal here is to replicate the approach of the winning team of the DCASE 2018 challenge (Task 2). Their technical report is here:
http://dcase.community/documents/challenge2018/technical_reports/DCASE2018_Jeong_102.pdf


Their implementation is here:
https://github.com/finejuly/dcase2018_task2_cochlearai



### Architecture
<img src="images/high_level_arch.png" width="25%"/><img src="images/block_arch.png" width="70%"/>

where "SE" denotes the Squeeze-and-Excitation block described here: https://arxiv.org/abs/1709.01507
<img src="images/se_arch.png" width="70%"/>


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class SEDenseLayer(nn.Module):
    """Dense-connectivity layer whose new feature maps are gated by a Squeeze-and-Excitation block.

    The layer computes ``nf_add`` new channels from the input using
    BN-ReLU-Conv1x1 followed by BN-ReLU-Conv3x3. The new channels are
    rescaled per channel by an SE gate, a sigmoid over a two-layer MLP on
    their spatial average. The result is concatenated after the unchanged
    input, so the output has ``nf_in + nf_add`` channels and the same
    spatial size as the input (3x3 conv with padding=1, stride 1).

    Args:
        nf_in: number of input channels.
        nf_add: number of new channels produced and appended.
        reduction: SE bottleneck reduction ratio. The hidden width is
            ``nf_add // reduction``, clamped to at least 1. The default of 2
            reproduces the original configuration.
    """

    def __init__(self, nf_in: int, nf_add: int, reduction: int = 2):
        # nn.Module.__init__ must run before any attribute is assigned on the module.
        super().__init__()
        self.nf_in, self.nf_add = nf_in, nf_add

        self.dense_layers = nn.Sequential(
            nn.BatchNorm2d(nf_in),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=nf_in, out_channels=nf_in, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(nf_in),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=nf_in, out_channels=nf_add, kernel_size=3, stride=1, padding=1, bias=False),
        )

        # "Squeeze": global average over the spatial dims -> one value per channel.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # "Excitation": bottleneck MLP producing a (0, 1) weight per channel.
        # Clamp so a tiny nf_add never yields a zero-width Linear layer.
        hidden = max(1, nf_add // reduction)
        self.se_layers = nn.Sequential(
            nn.Linear(nf_add, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, nf_add, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``cat([x, SE(dense(x))], dim=1)``; x is assumed (B, nf_in, H, W)."""
        d = self.dense_layers(x)

        b, f, _, _ = d.size()
        gate = self.se_layers(self.avg_pool(d).view(b, f)).view(b, f, 1, 1)
        # Broadcasting over (H, W) applies one gate value per channel.
        return torch.cat([x, d * gate], 1)

class SEDenseNet(nn.Module):
    """Stack of SE dense layers with max-pooling, followed by averaged linear heads.

    Input is assumed to be (B, 1, H, W), presumably a single-channel
    spectrogram (TODO confirm against the data pipeline). A 3x3 conv adds 15
    learned maps to the raw input, giving 16 channels. Eight
    (SEDenseLayer, MaxPool2d(2)) stages then grow this to 2048 channels.
    Each spatial dimension must be at least 256 (2**8); otherwise one of the
    eight 2x2 max-pools would shrink it to zero.

    Args:
        num_classes: output size of each linear head (default 80, as originally hard-coded).
        num_heads: number of independent linear heads whose outputs are averaged (default 8).
    """

    # (nf_in, nf_add) for each dense stage; each stage outputs nf_in + nf_add channels.
    _STAGES = (
        (16, 16),
        (32, 32),
        (64, 64),
        (128, 128),
        (256, 256),
        (512, 512),
        (512 + 512, 512),
        (512 + 512 + 512, 512),
    )

    def __init__(self, num_classes: int = 80, num_heads: int = 8):
        super().__init__()

        self.first_conv = nn.Conv2d(in_channels=1, out_channels=15, kernel_size=3, stride=1, padding=1, bias=False)

        # Same module order as an explicit listing, so Sequential indices
        # (and therefore state_dict keys) and parameter-init order are preserved.
        stages = []
        for nf_in, nf_add in self._STAGES:
            stages += [SEDenseLayer(nf_in, nf_add), nn.MaxPool2d(2)]
        self.se_dense_layers = nn.Sequential(*stages)

        last_in, last_add = self._STAGES[-1]
        feat_dim = last_in + last_add  # 2048

        self.linears = nn.ModuleList(
            [nn.Linear(feat_dim, num_classes) for _ in range(num_heads)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return (B, num_classes) logits: the mean of the per-head linear outputs."""
        # Raw input (1 ch) + 15 learned maps -> 16 channels for the first dense stage.
        y = torch.cat([x, self.first_conv(x)], 1)

        feats = self.se_dense_layers(y)
        # Reduce any remaining spatial extent to a single (B, C) vector. This is the
        # identity when the map is already 1x1 (spatial dims 256..511). It also lets
        # larger inputs work instead of leaving a 3-D/4-D tensor for the heads.
        # Max is used for consistency with the MaxPool2d stages.
        feats = F.adaptive_max_pool2d(feats, 1).flatten(1)

        # NOTE(review): averaging linear heads with no nonlinearity in between
        # is itself a single linear map; kept to match the original design.
        return torch.stack([head(feats) for head in self.linears]).mean(dim=0)