# Logging

In [None]:
# Set WANDB_SILENT for this kernel, then log in to Weights & Biases before any run is created.
%env WANDB_SILENT=true
import wandb
wandb.login()

env: WANDB_SILENT=true


True

# Imports

In [None]:
from __future__ import annotations
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback
from fastxtend.vision.all import *
from tfrecord.reader import example_loader

In [None]:
import warnings
# Ignore every UserWarning raised from here on (applies to the whole kernel session).
warnings.filterwarnings("ignore", category=UserWarning)

# AttentionPooling XResNet

In [None]:
def xeca_resnext50s(n_out=1000, **kwargs):
    "Build an `XResNet` from `ECAResNeXtBlock` with three block groups ([3, 4, 9] blocks, sizes 64/128/256)."
    layers, block_szs = [3, 4, 9], [64, 128, 256]
    # The second positional argument (4) is forwarded unchanged; any extra kwargs go straight to `XResNet`.
    return XResNet(ECAResNeXtBlock, 4, layers, block_szs=block_szs, n_out=n_out, **kwargs)

# Imagenette/Woof

In [None]:
# Two three-element lists per dataset, unpacked into `Normalize.from_stats` by `get_dls`.
# presumably per-channel (mean, std) of each dataset — provenance is not shown in this notebook.
imagewoof_stats =  ([0.496,0.461,0.399],[0.257,0.249,0.258])
imagenette_stats = ([0.465,0.458,0.429],[0.285,0.28,0.301])

def get_dls(size, woof, bs, sh=0., augs=None, workers=None, stats=True):
    """Create Imagewoof or Imagenette dataloaders at a given image size.

    size:    target size handed to `RandomResizedCrop`; also selects which download is used
    woof:    truthy for Imagewoof, falsy for Imagenette
    bs:      batch size
    sh:      when truthy, adds `RandomErasing(p=0.3, max_count=3, sh=sh)` as the last batch transform
    augs:    optional iterable of extra batch transforms, placed after normalisation
    workers: dataloader worker count; defaults to `min(8, num_cpus())`
    stats:   when truthy, normalise with the dataset-specific statistics
    """
    # Anything larger than 224 needs the full-size download; otherwise the 320 variant is used.
    if size > 224:
        url = URLs.IMAGEWOOF if woof else URLs.IMAGENETTE
    else:
        url = URLs.IMAGEWOOF_320 if woof else URLs.IMAGENETTE_320
    source = untar_data(url)
    if workers is None:
        workers = min(8, num_cpus())

    # Order matters: normalisation, then caller augmentations, then random erasing.
    batch_tfms = []
    if stats:
        dataset_stats = imagewoof_stats if woof else imagenette_stats
        batch_tfms.append(Normalize.from_stats(*dataset_stats))
    if augs:
        batch_tfms.extend(augs)
    if sh:
        batch_tfms.append(RandomErasing(p=0.3, max_count=3, sh=sh))

    item_tfms = [RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
    dblock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        get_y=parent_label,
        splitter=GrandparentSplitter(valid_name='val'),
        item_tfms=item_tfms,
        batch_tfms=batch_tfms,
    )
    return dblock.dataloaders(source, bs=bs, num_workers=workers)

# Training Code

In [None]:
def train(run_name, model, lr, epochs=80, bs=64, size=224, loss_func=LabelSmoothingCrossEntropyFlat(),
          batch_tfms=aug_transforms(max_zoom=1, max_rotate=20, xtra_tfms=[Hue(), Saturation()]),
          woof=True, seeds=[42,314,1618], valid_sizes=[256,384], log=True):
    """Train `model` once per seed and validate each trained model at extra image sizes.

    For every seed a fresh `Learner` is built (fp16, `ranger`, `CutMixUpAugment`), fitted with
    `fit_flat_cos(epochs, lr)` and saved as '{run_name} lr={lr} seed={seed}'. The learner is then
    validated at each size in `valid_sizes` with batch size 48 and no augmentations.

    run_name:    label used for the W&B run name and the saved weights
    model:       callable returning the network; called with `n_out`, `act_cls`, `stem_pool`, `block_pool`
    lr:          learning rate passed to `fit_flat_cos`
    epochs:      number of epochs per seed
    bs, size:    training batch size and image size
    loss_func:   loss handed to the `Learner`
    batch_tfms:  batch augmentations used for the training dataloaders
    woof:        Imagewoof when truthy, Imagenette otherwise
    seeds:       one full training run is performed per seed
    valid_sizes: a size or list of sizes to validate at after training
    log:         when truthy, log to W&B and write per-metric mean/std over seeds to the run summary;
                 otherwise the extra-size accuracy and Matthews score are printed

    Any exception is re-raised after cleanup; the W&B run is always finished.
    """
    # Bound unconditionally so neither cleanup path below can hit an unbound or None name
    # and hide the exception that actually ended the run.
    run, run_results, broke = None, None, False
    if log:
        run = wandb.init(project="AttentionPooling", name=f'{run_name} lr={lr}',
                         group= f'ImageWoof {epochs}E' if woof else f'Imagenette {epochs}E')
    try:
        if not is_listy(valid_sizes):
            valid_sizes = [valid_sizes]

        for seed in seeds:
            # Reset per seed so a failure while building this seed's data/learner
            # never "frees" objects left over from the previous seed.
            learn, dls = None, None
            try:
                with less_random(seed):
                    dls = get_dls(size, woof, bs, augs=batch_tfms)

                cbs = [WandbCallback(log=None, log_preds=False, log_model=False)] if log else []
                with less_random(seed):
                    learn = Learner(dls, model(n_out=dls.c, act_cls=nn.Mish, stem_pool=MaxBlurPool, block_pool=BlurPool),
                                    loss_func=loss_func, opt_func=ranger,
                                    metrics=[accuracy, MatthewsCorrCoef()],
                                    cbs=[CutMixUpAugment()]+cbs).to_fp16()

                with less_random(seed):
                    learn.fit_flat_cos(epochs, lr)
                    learn.save(f'{run_name} lr={lr} seed={seed}', with_opt=False)

                if log:
                    # Keep the last recorded value of every metric except the bookkeeping columns.
                    if run_results is None: run_results = {}
                    for n,v in zip(learn.recorder.metric_names, learn.recorder.log):
                        if n not in ['epoch', 'time']: run_results.setdefault(n, []).append(v)

                # The W&B callback is dropped so the extra validation passes are not logged as steps.
                if log: learn.remove_cb(WandbCallback)
                for vs in valid_sizes:
                    dls = None
                    learn.dls = get_dls(vs, woof, 48)
                    loss, acc, matthews = learn.validate()
                    if log:
                        for n,v in zip([f'loss_{vs}', f'accuracy_{vs}', f'matthews_{vs}'], [loss, acc, matthews]):
                            run_results.setdefault(n, []).append(v)
                    else:
                        print(f'accuracy_{vs}: {acc}, matthews_{vs}: {matthews}')
            except CancelFitException:
                print('Terminate on NAN')
                broke = True
                raise
            finally:
                # Nothing to release if the learner was never created.
                if learn is not None: free_gpu_memory(learn, dls)

    finally:
        if log:
            try:
                # Skip the summary when training was cancelled or no seed produced results.
                if not broke and run_results:
                    for n, vals in run_results.items():
                        run.summary[f'{n}_mean'] = np.mean(vals)
                        run.summary[f'{n}_std'] = np.std(vals)
            finally:
                run.finish()

# XResNeXt50

In [None]:
# Baseline run. `xeca_resnext50` is not defined in this notebook — presumably it comes from one of the star imports (verify).
# NOTE(review): the tables below list 10 epochs per seed while `train` defaults to epochs=80 — the saved output may predate that default.
train('XResNeXt50', xeca_resnext50, lr=8e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.275173,2.266385,0.175872,0.091372,00:36
1,2.207791,2.049352,0.29575,0.22445,00:35
2,2.073383,2.02637,0.348435,0.281999,00:36
3,1.945767,1.699834,0.468822,0.412319,00:35
4,1.886709,1.70785,0.441843,0.38498,00:35
5,1.789281,1.529827,0.540087,0.489451,00:35
6,1.715256,1.685355,0.48562,0.452684,00:35
7,1.630681,1.354193,0.630695,0.590785,00:35
8,1.614853,1.367146,0.618223,0.581715,00:35
9,1.523062,1.214201,0.709595,0.677902,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.268112,2.234637,0.206159,0.124996,00:35
1,2.175589,2.021521,0.31382,0.244447,00:34
2,2.077675,2.425421,0.235429,0.17856,00:35
3,1.951811,1.754004,0.435225,0.379809,00:35
4,1.838389,1.731614,0.455332,0.400508,00:35
5,1.753304,1.513417,0.559684,0.513808,00:35
6,1.681683,1.633003,0.512344,0.462692,00:35
7,1.636609,1.303074,0.664291,0.628317,00:35
8,1.560014,1.40187,0.619242,0.581453,00:35
9,1.537989,1.215009,0.692797,0.661941,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.262358,2.340616,0.204632,0.127643,00:35
1,2.150326,2.006909,0.334945,0.262031,00:35
2,1.995332,1.947241,0.332655,0.274761,00:35
3,1.904399,1.621464,0.504709,0.449185,00:35
4,1.83489,1.611827,0.513108,0.464535,00:35
5,1.728539,1.39751,0.618733,0.577242,00:35
6,1.669606,1.468317,0.57012,0.528711,00:35
7,1.573947,1.262516,0.675235,0.638876,00:35
8,1.569396,1.266263,0.675744,0.640919,00:35
9,1.469698,1.137706,0.734793,0.70523,00:36


# XResNeXt50S

In [None]:
# Same settings as the XResNeXt50 run, using the `xeca_resnext50s` builder defined in this notebook.
train('XResNeXt50S', xeca_resnext50s, lr=8e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.210137,2.228553,0.216086,0.153088,00:34
1,2.050555,1.889555,0.39832,0.333262,00:34
2,1.966817,1.858322,0.38763,0.326757,00:34
3,1.817924,1.64075,0.497837,0.448233,00:34
4,1.759484,1.47617,0.579282,0.534623,00:34
5,1.650967,1.382793,0.608552,0.566855,00:34
6,1.594731,1.544204,0.561721,0.532921,00:34
7,1.513079,1.237,0.683635,0.650111,00:34
8,1.50676,1.379514,0.621787,0.587657,00:34
9,1.416615,1.084331,0.763553,0.73786,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.198302,3.292138,0.201323,0.12421,00:34
1,2.07371,1.882615,0.35887,0.290406,00:33
2,1.986684,2.066887,0.319165,0.263588,00:33
3,1.855697,1.600134,0.532196,0.4823,00:35
4,1.743419,1.531552,0.540087,0.491444,00:34
5,1.661834,1.34795,0.64393,0.605121,00:34
6,1.585903,1.49189,0.576737,0.536174,00:34
7,1.571841,1.236178,0.689488,0.655506,00:34
8,1.492612,1.363058,0.637821,0.602673,00:34
9,1.488464,1.14327,0.733265,0.704979,00:34


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.219316,2.21232,0.218885,0.143482,00:33
1,2.098676,1.943067,0.343599,0.276729,00:33
2,1.959928,1.846583,0.402647,0.343564,00:33
3,1.877982,1.606843,0.513871,0.462357,00:33
4,1.789014,1.677233,0.492492,0.44741,00:33
5,1.683933,1.330364,0.644439,0.605497,00:34
6,1.624085,1.569354,0.517689,0.475533,00:33
7,1.520611,1.192293,0.717231,0.686283,00:34
8,1.515788,1.298577,0.661237,0.626616,00:33
9,1.425708,1.128217,0.742683,0.71484,00:34


# Learned Aggregation

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation"
    def __init__(self,
        ni:int, # Number of input channels; also the width of every projection
        bias:bool=True, # Whether the query and key/value projections use a bias
        norm:Callable[[int], nn.Module]=nn.LayerNorm # Factory called as `norm(ni)` for the input norm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # Batch norms are applied with channels on dim 1; any other norm is applied to the last dim.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "Apply the batch norm to channels-first `x` and return it as (B, N, C)."
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4D input, so it must run before the spatial dims are flattened.
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten channels-first `x` to (B, N, C), then apply `self.norm` on the trailing dim."
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        "Pool `x` into one `C`-dim vector per sample, using `cls_q` as the single query."
        x = self.norm_forward(x)
        B, N, C = x.shape

        # The shared query is broadcast across the batch before projection.
        q = self.q(cls_q.expand(B, -1, -1))
        # `vk` emits 2*C features per position: the first C are used as keys, the last C as values.
        k, v = self.vk(x).reshape(B, N, 2, C).permute(2, 0, 1, 3).chunk(2, 0)

        # Single-head attention over the N positions; the logits are not scaled by 1/sqrt(C).
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, C)
        return self.proj(x)
 
    
class LearnedAggregation(nn.Module):
    "Learned Aggregation from https://arxiv.org/abs/2112.13692"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        hidden = int(ni*ffn_expand)
        # Per-channel scales for the two residual branches, both starting at 1e-4.
        self.gamma_1 = nn.Parameter(torch.ones(ni) * 1e-4)
        self.gamma_2 = nn.Parameter(torch.ones(ni) * 1e-4)
        # Learned query, re-initialised below with a truncated normal.
        self.cls_q = nn.Parameter(torch.zeros(1, ni))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.norm = norm(ni)
        ffn_layers = [nn.Linear(ni, hidden), act_cls(), nn.Linear(hidden, ni)]
        self.ffn = nn.Sequential(*ffn_layers)
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        "Attention-pool `x`, add the result to the query, then add a scaled feed-forward update."
        pooled = self.attn(x, self.cls_q)
        token = self.cls_q + self.gamma_1 * pooled
        update = self.ffn(self.norm(token))
        return token + self.gamma_2 * update

    @torch.no_grad()
    def _init_weights(self, m):
        "Truncated-normal weights and zero bias for `nn.Linear`; every other module is left untouched."
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [None]:
def LearnedAggregationHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Return `[LearnedAggregation, norm(ni), Linear(ni, n_out)]` as a list of head layers."
    aggregate = LearnedAggregation(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    head_norm = norm(ni)
    classifier = nn.Linear(ni, n_out)
    with torch.no_grad():
        # `_init_weights` only changes `nn.Linear` modules, so the norm passes through untouched.
        for layer in (head_norm, classifier):
            aggregate._init_weights(layer)
    return [aggregate, head_norm, classifier]

In [None]:
# Replace the model head with `LearnedAggregationHead` through `custom_head`; this run uses lr=9e-4.
train('XResNeXt50S LearnedAggregation', partial(xeca_resnext50s, custom_head=LearnedAggregationHead), lr=9e-4)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.199715,2.074349,0.256554,0.189763,00:41
1,2.113541,1.915713,0.374396,0.308368,00:35
2,2.058154,1.923991,0.339781,0.284045,00:35
3,1.960853,1.740424,0.449224,0.390719,00:36
4,1.949185,1.69369,0.500382,0.445497,00:35
5,1.860092,1.70176,0.453296,0.398683,00:35
6,1.829159,1.631393,0.504454,0.455855,00:35
7,1.748034,1.646558,0.48282,0.434684,00:35
8,1.737887,1.436802,0.601425,0.560294,00:35
9,1.65019,1.510172,0.561466,0.516201,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.222728,2.111834,0.213031,0.137394,00:35
1,2.148643,1.971029,0.336472,0.264311,00:35
2,2.075427,1.94241,0.331128,0.269409,00:35
3,1.999005,1.752864,0.450496,0.392025,00:35
4,1.899158,1.705505,0.447951,0.392992,00:35
5,1.857024,1.603351,0.513108,0.460021,00:35
6,1.788565,1.557084,0.537287,0.487598,00:35
7,1.770495,1.571936,0.53576,0.487702,00:35
8,1.689481,1.41257,0.610842,0.567438,00:35
9,1.691272,1.447476,0.587936,0.544425,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.218343,2.085494,0.260626,0.187821,00:35
1,2.143909,1.963208,0.343599,0.274539,00:35
2,2.025316,1.942327,0.350471,0.289534,00:35
3,1.989517,1.723473,0.455841,0.398157,00:35
4,1.942484,1.714366,0.465513,0.410322,00:35
5,1.8653,1.64248,0.489692,0.442875,00:35
6,1.825867,1.566904,0.524815,0.476785,00:35
7,1.773232,1.669179,0.463731,0.410693,00:35
8,1.735302,1.389533,0.617205,0.575659,00:35
9,1.673775,1.4545,0.591499,0.547901,00:35


# Learned Aggregation Sandwich

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation"
    def __init__(self,
        ni:int, # Number of input channels; also the width of every projection
        bias:bool=True, # Whether the query and key/value projections use a bias
        norm:Callable[[int], nn.Module]=nn.LayerNorm # Factory called as `norm(ni)` for the input norm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # Batch norms are applied with channels on dim 1; any other norm is applied to the last dim.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "Apply the batch norm to channels-first `x` and return it as (B, N, C)."
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4D input, so it must run before the spatial dims are flattened.
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten channels-first `x` to (B, N, C), then apply `self.norm` on the trailing dim."
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        "Pool `x` into one `C`-dim vector per sample, using `cls_q` as the single query."
        x = self.norm_forward(x)
        B, N, C = x.shape

        # The shared query is broadcast across the batch before projection.
        q = self.q(cls_q.expand(B, -1, -1))
        # `vk` emits 2*C features per position: the first C are used as keys, the last C as values.
        k, v = self.vk(x).reshape(B, N, 2, C).permute(2, 0, 1, 3).chunk(2, 0)

        # Single-head attention over the N positions; the logits are not scaled by 1/sqrt(C).
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, C)
        return self.proj(x)
 
    
class LearnedAggregation(nn.Module):
    "Learned Aggregation from https://arxiv.org/abs/2112.13692"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        hidden = int(ni*ffn_expand)
        # Per-channel scales for the two residual branches, both starting at 1e-4.
        self.gamma_1 = nn.Parameter(torch.ones(ni) * 1e-4)
        self.gamma_2 = nn.Parameter(torch.ones(ni) * 1e-4)
        # Learned query, re-initialised below with a truncated normal.
        self.cls_q = nn.Parameter(torch.zeros(1, ni))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        # Sandwich variant: one norm after attention, one before the FFN, one inside the FFN.
        self.norm1 = norm(ni)
        self.norm2 = norm(ni)
        ffn_layers = [nn.Linear(ni, hidden), act_cls(), norm(hidden), nn.Linear(hidden, ni)]
        self.ffn = nn.Sequential(*ffn_layers)
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        "Attention-pool and normalise `x`, add it to the query, then add a scaled feed-forward update."
        pooled = self.norm1(self.attn(x, self.cls_q))
        token = self.cls_q + self.gamma_1 * pooled
        update = self.ffn(self.norm2(token))
        return token + self.gamma_2 * update

    @torch.no_grad()
    def _init_weights(self, m):
        "Truncated-normal weights and zero bias for `nn.Linear`; every other module is left untouched."
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [None]:
def LearnedAggregationHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Return `[LearnedAggregation, norm(ni), Linear(ni, n_out)]` as a list of head layers."
    aggregate = LearnedAggregation(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    head_norm = norm(ni)
    classifier = nn.Linear(ni, n_out)
    with torch.no_grad():
        # `_init_weights` only changes `nn.Linear` modules, so the norm passes through untouched.
        for layer in (head_norm, classifier):
            aggregate._init_weights(layer)
    return [aggregate, head_norm, classifier]

In [None]:
# Uses the sandwich `LearnedAggregation` defined in the cells directly above (extra norms after attention and inside the FFN).
train('XResNeXt50S LearnedAggregation Sandwich', partial(xeca_resnext50s, custom_head=LearnedAggregationHead), lr=1e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.16293,2.036501,0.266989,0.200788,00:42
1,2.114321,1.953784,0.349962,0.284505,00:36
2,2.054463,1.898321,0.375923,0.318136,00:35
3,1.964612,1.775316,0.434971,0.374972,00:36
4,1.927525,1.677666,0.493255,0.437373,00:36
5,1.849304,1.710855,0.466276,0.410883,00:36
6,1.799846,1.600896,0.526343,0.480863,00:36
7,1.723859,1.513466,0.54874,0.501282,00:37
8,1.713125,1.391695,0.629677,0.588739,00:36
9,1.622702,1.43595,0.607534,0.565766,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.165343,2.080783,0.266989,0.189603,00:35
1,2.108735,1.910543,0.370323,0.303133,00:35
2,2.047782,1.883989,0.357852,0.294685,00:35
3,1.971045,1.861304,0.359634,0.305784,00:36
4,1.867591,1.622243,0.499364,0.445995,00:35
5,1.818781,1.564354,0.53576,0.485806,00:35
6,1.748495,1.498617,0.578519,0.533445,00:36
7,1.713022,1.473939,0.583864,0.539793,00:35
8,1.655845,1.364,0.640876,0.601707,00:36
9,1.639026,1.362305,0.631967,0.592114,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.172347,2.046285,0.272079,0.200696,00:35
1,2.130665,1.922118,0.342326,0.274302,00:35
2,2.008561,1.922264,0.356579,0.293546,00:36
3,1.98922,1.826526,0.382031,0.327139,00:35
4,1.924709,1.698948,0.485111,0.433171,00:36
5,1.834007,1.582563,0.533469,0.482967,00:36
6,1.796024,1.487679,0.577755,0.531508,00:35
7,1.721547,1.447338,0.581573,0.537745,00:36
8,1.694872,1.344667,0.631204,0.590602,00:36
9,1.62355,1.373183,0.622805,0.582666,00:37


# Learned Aggregation Sandwich Act

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation"
    def __init__(self,
        ni:int, # Number of input channels; also the width of every projection
        bias:bool=True, # Whether the query and key/value projections use a bias
        norm:Callable[[int], nn.Module]=nn.LayerNorm # Factory called as `norm(ni)` for the input norm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # Batch norms are applied with channels on dim 1; any other norm is applied to the last dim.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "Apply the batch norm to channels-first `x` and return it as (B, N, C)."
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4D input, so it must run before the spatial dims are flattened.
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten channels-first `x` to (B, N, C), then apply `self.norm` on the trailing dim."
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        "Pool `x` into one `C`-dim vector per sample, using `cls_q` as the single query."
        x = self.norm_forward(x)
        B, N, C = x.shape

        # The shared query is broadcast across the batch before projection.
        q = self.q(cls_q.expand(B, -1, -1))
        # `vk` emits 2*C features per position: the first C are used as keys, the last C as values.
        k, v = self.vk(x).reshape(B, N, 2, C).permute(2, 0, 1, 3).chunk(2, 0)

        # Single-head attention over the N positions; the logits are not scaled by 1/sqrt(C).
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, C)
        return self.proj(x)
 
    
class LearnedAggregation(nn.Module):
    "Learned Aggregation from https://arxiv.org/abs/2112.13692"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        hidden = int(ni*ffn_expand)
        # Per-channel scales for the two residual branches, both starting at 1e-4.
        self.gamma_1 = nn.Parameter(torch.ones(ni) * 1e-4)
        self.gamma_2 = nn.Parameter(torch.ones(ni) * 1e-4)
        # Learned query, re-initialised below with a truncated normal.
        self.cls_q = nn.Parameter(torch.zeros(1, ni))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        # Sandwich variant: one norm after attention, one before the FFN, one inside the FFN.
        self.norm1 = norm(ni)
        self.norm2 = norm(ni)
        ffn_layers = [nn.Linear(ni, hidden), act_cls(), norm(hidden), nn.Linear(hidden, ni)]
        self.ffn = nn.Sequential(*ffn_layers)
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        "Attention-pool and normalise `x`, add it to the query, then add a scaled feed-forward update."
        pooled = self.norm1(self.attn(x, self.cls_q))
        token = self.cls_q + self.gamma_1 * pooled
        update = self.ffn(self.norm2(token))
        return token + self.gamma_2 * update

    @torch.no_grad()
    def _init_weights(self, m):
        "Truncated-normal weights and zero bias for `nn.Linear`; every other module is left untouched."
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [None]:
def LearnedAggregationHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, act_cls=nn.GELU, **kwargs):
    "Return `[LearnedAggregation, norm(ni), act_cls(), Linear(ni, n_out)]` as a list of head layers."
    head = [LearnedAggregation(ni, norm=norm, ffn_expand=ffn_expand, act_cls=act_cls, **kwargs), norm(ni), act_cls(), nn.Linear(ni, n_out)]
    with torch.no_grad():
        # Visit every trailing layer rather than fixed indices: with the activation inserted,
        # the classifier sits at index 3 and would otherwise keep PyTorch's default init.
        # `_init_weights` only changes `nn.Linear`, so the norm and activation are unaffected.
        for m in head[1:]:
            head[0]._init_weights(m)
    return head

In [None]:
# Same sandwich aggregation, with the head variant that inserts an activation between the final norm and the classifier.
train('XResNeXt50S LearnedAggregation Sandwich Act', partial(xeca_resnext50s, custom_head=LearnedAggregationHead), lr=1e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.164388,2.1107,0.230084,0.159082,00:36
1,2.075185,1.930684,0.37745,0.311404,00:36
2,2.001977,1.840582,0.39552,0.340996,00:35
3,1.882548,1.655813,0.505218,0.45103,00:36
4,1.866592,1.618268,0.50649,0.455723,00:35
5,1.782115,1.522692,0.547213,0.498728,00:35
6,1.738984,1.648432,0.481547,0.443717,00:35
7,1.676545,1.407655,0.608806,0.565444,00:36
8,1.660331,1.451688,0.580555,0.538561,00:35
9,1.578126,1.295559,0.661237,0.624191,00:37


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.152368,2.039074,0.272334,0.205681,00:36
1,2.082298,1.886147,0.389921,0.322833,00:35
2,2.010265,1.840331,0.372105,0.310159,00:35
3,1.933848,1.688125,0.485111,0.429255,00:36
4,1.836339,1.603956,0.495546,0.444773,00:35
5,1.789663,1.487964,0.573683,0.526817,00:35
6,1.725387,1.59335,0.517434,0.468626,00:36
7,1.694936,1.372732,0.635022,0.593956,00:36
8,1.6369,1.409501,0.610588,0.567275,00:36
9,1.641824,1.322783,0.644439,0.60663,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.157185,2.023455,0.290659,0.220023,00:35
1,2.075041,1.904002,0.383049,0.314378,00:35
2,1.941383,1.784949,0.423517,0.363365,00:35
3,1.885943,1.628339,0.508272,0.453367,00:35
4,1.83548,1.627799,0.511581,0.458851,00:36
5,1.761568,1.4451,0.590227,0.54552,00:35
6,1.712351,1.515052,0.555867,0.510802,00:35
7,1.638793,1.33212,0.647748,0.607796,00:36
8,1.629362,1.343239,0.627641,0.588509,00:36
9,1.555137,1.245339,0.681344,0.645436,00:35


# AvgAttnPooling2d Sandwich LA Gamma

In [None]:
class AttentionPool2d(nn.Module):
    "Single-query attention pooling over the spatial positions of a feature map"
    def __init__(self,
        ni:int, # Number of input channels; also the width of every projection
        bias:bool=True, # Whether the query and key/value projections use a bias
        norm:Callable[[int], nn.Module]=nn.LayerNorm # Factory called as `norm(ni)` for the input norm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # Batch norms are applied with channels on dim 1; any other norm is applied to the last dim.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "Apply the batch norm to channels-first `x` and return it as (B, N, C)."
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4D input, so it must run before the spatial dims are flattened.
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten channels-first `x` to (B, N, C), then apply `self.norm` on the trailing dim."
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        "Pool `x` into one `C`-dim vector per sample, using `cls_q` as the single query."
        x = self.norm_forward(x)
        B, N, C = x.shape

        # The shared query is broadcast across the batch before projection.
        q = self.q(cls_q.expand(B, -1, -1))
        # `vk` emits 2*C features per position: the first C are used as keys, the last C as values.
        k, v = self.vk(x).reshape(B, N, 2, C).permute(2, 0, 1, 3).chunk(2, 0)

        # Single-head attention over the N positions; the logits are not scaled by 1/sqrt(C).
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnPooling2d(nn.Module):
    "Sum of average pooling and attention pooling, followed by a scaled feed-forward residual"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        hidden = int(ni*ffn_expand)
        # Learned query, re-initialised below with a truncated normal.
        self.cls_q = nn.Parameter(torch.zeros(1, ni))
        # Per-channel scales, both starting at 1e-3.
        self.gamma_1 = nn.Parameter(torch.ones(ni) * 1e-3)
        self.gamma_2 = nn.Parameter(torch.ones(ni) * 1e-3)
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm = norm(ni)
        ffn_layers = [nn.Linear(ni, hidden), act_cls(), norm(hidden), nn.Linear(hidden, ni)]
        self.ffn = nn.Sequential(*ffn_layers)
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        "Normalise avg-pool + attention-pool, scale by `gamma_1`, then add the `gamma_2`-scaled FFN output."
        averaged = self.pool(x).flatten(1)
        combined = averaged + self.attn(x, self.cls_q)
        pooled = self.gamma_1 * self.norm(combined)
        return pooled + self.gamma_2 * self.ffn(pooled)

    @torch.no_grad()
    def _init_weights(self, m):
        "Truncated-normal weights and zero bias for `nn.Linear`; every other module is left untouched."
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [None]:
def AvgAttnPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Return `[AvgAttnPooling2d, norm(ni), Linear(ni, n_out)]` as a list of head layers."
    pooling = AvgAttnPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    head_norm = norm(ni)
    classifier = nn.Linear(ni, n_out)
    with torch.no_grad():
        # `_init_weights` only changes `nn.Linear` modules, so the norm passes through untouched.
        for layer in (head_norm, classifier):
            pooling._init_weights(layer)
    return [pooling, head_norm, classifier]

In [None]:
# Average + attention pooling head with the learnable `gamma_1`/`gamma_2` scales; this run uses lr=2e-3.
train('XResNeXt50S AvgAttnPool Sandwich LAGamma', partial(xeca_resnext50s, custom_head=AvgAttnPoolHead), lr=2e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.218582,2.10168,0.27717,0.207716,00:36
1,2.087541,2.069672,0.313057,0.250398,00:35
2,2.030333,1.848468,0.435225,0.373977,00:35
3,1.936322,1.770031,0.433444,0.372778,00:36
4,1.89528,1.592759,0.52507,0.471701,00:35
5,1.803183,1.661347,0.497073,0.444504,00:35
6,1.755874,1.502527,0.568592,0.524554,00:35
7,1.662472,1.523921,0.53805,0.492324,00:36
8,1.64698,1.324765,0.666582,0.62991,00:35
9,1.552462,1.388726,0.628659,0.591207,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.219265,2.080118,0.272843,0.194331,00:35
1,2.118569,1.945393,0.333673,0.261179,00:36
2,2.038664,1.801727,0.438534,0.378347,00:35
3,1.942253,1.897742,0.370578,0.31119,00:35
4,1.863841,1.539714,0.550522,0.500741,00:35
5,1.782749,1.567509,0.52507,0.475703,00:35
6,1.702776,1.4172,0.611606,0.569672,00:36
7,1.69582,1.38782,0.630186,0.589794,00:36
8,1.603765,1.285112,0.677781,0.641759,00:36
9,1.594832,1.353362,0.63604,0.599204,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.217733,2.09387,0.282769,0.213296,00:35
1,2.106164,1.92576,0.348944,0.280284,00:35
2,1.976718,1.806582,0.423517,0.365468,00:35
3,1.934481,1.791562,0.40901,0.355784,00:35
4,1.880269,1.600621,0.536523,0.488765,00:35
5,1.792796,1.55708,0.526852,0.480891,00:36
6,1.741612,1.427618,0.611861,0.570035,00:35
7,1.674702,1.521971,0.530924,0.496919,00:36
8,1.640675,1.248966,0.683125,0.64836,00:36
9,1.565032,1.33278,0.640366,0.602562,00:36


# AvgAttnPooling2d Sandwich Act

In [None]:
class AttentionPool2d(nn.Module):
    "Single-query attention pooling over the spatial positions of a feature map"
    def __init__(self,
        ni:int, # Number of input channels; also the width of every projection
        bias:bool=True, # Whether the query and key/value projections use a bias
        norm:Callable[[int], nn.Module]=nn.LayerNorm # Factory called as `norm(ni)` for the input norm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # Batch norms are applied with channels on dim 1; any other norm is applied to the last dim.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "Apply the batch norm to channels-first `x` and return it as (B, N, C)."
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4D input, so it must run before the spatial dims are flattened.
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten channels-first `x` to (B, N, C), then apply `self.norm` on the trailing dim."
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        "Pool `x` into one `C`-dim vector per sample, using `cls_q` as the single query."
        x = self.norm_forward(x)
        B, N, C = x.shape

        # The shared query is broadcast across the batch before projection.
        q = self.q(cls_q.expand(B, -1, -1))
        # `vk` emits 2*C features per position: the first C are used as keys, the last C as values.
        k, v = self.vk(x).reshape(B, N, 2, C).permute(2, 0, 1, 3).chunk(2, 0)

        # Single-head attention over the N positions; the logits are not scaled by 1/sqrt(C).
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnPooling2d(nn.Module):
    "Sum of average pooling and attention pooling, followed by an unscaled feed-forward residual"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        hidden = int(ni*ffn_expand)
        # Learned query, re-initialised below with a truncated normal.
        self.cls_q = nn.Parameter(torch.zeros(1, ni))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm = norm(ni)
        ffn_layers = [nn.Linear(ni, hidden), act_cls(), norm(hidden), nn.Linear(hidden, ni)]
        self.ffn = nn.Sequential(*ffn_layers)
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        "Normalise avg-pool + attention-pool, then add the FFN output as a residual."
        averaged = self.pool(x).flatten(1)
        combined = averaged + self.attn(x, self.cls_q)
        pooled = self.norm(combined)
        return pooled + self.ffn(pooled)

    @torch.no_grad()
    def _init_weights(self, m):
        "Truncated-normal weights and zero bias for `nn.Linear`; every other module is left untouched."
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [None]:
def AvgAttnPoolHead(ni, n_out, norm=nn.LayerNorm, act_cls=nn.GELU, ffn_expand=3, **kwargs):
    "Return `[AvgAttnPooling2d, norm(ni), act_cls(), Linear(ni, n_out)]` as a list of head layers."
    head = [AvgAttnPooling2d(ni, norm=norm, ffn_expand=ffn_expand, act_cls=act_cls, **kwargs), norm(ni), act_cls(), nn.Linear(ni, n_out)]
    with torch.no_grad():
        # Visit every trailing layer rather than fixed indices: with the activation inserted,
        # the classifier sits at index 3 and would otherwise keep PyTorch's default init.
        # `_init_weights` only changes `nn.Linear`, so the norm and activation are unaffected.
        for m in head[1:]:
            head[0]._init_weights(m)
    return head

In [None]:
# Average + attention pooling without the gamma scales, using the head variant with an activation before the classifier.
train('XResNeXt50S AvgAttnPool Sandwich Act', partial(xeca_resnext50s, custom_head=AvgAttnPoolHead), lr=1e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.14853,2.155521,0.227539,0.15921,00:35
1,2.050369,1.844701,0.412828,0.350112,00:35
2,1.993673,1.820871,0.400356,0.34111,00:35
3,1.867723,1.614109,0.51209,0.4585,00:36
4,1.866067,1.697676,0.459404,0.406174,00:35
5,1.793833,1.534654,0.538814,0.490873,00:35
6,1.759429,1.684922,0.472639,0.425858,00:35
7,1.69019,1.433376,0.589463,0.544225,00:35
8,1.686307,1.528879,0.531942,0.490082,00:35
9,1.617431,1.330491,0.652074,0.613941,00:38


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.141338,1.960067,0.341308,0.270957,00:35
1,2.056163,1.841487,0.39603,0.330095,00:35
2,1.995072,1.804457,0.423263,0.36479,00:35
3,1.920931,1.671572,0.483075,0.425493,00:35
4,1.838118,1.609304,0.496564,0.443167,00:35
5,1.799824,1.491811,0.573937,0.527052,00:35
6,1.739381,1.511634,0.55943,0.512385,00:35
7,1.720034,1.409274,0.611351,0.568529,00:35
8,1.662797,1.504433,0.570374,0.524885,00:35
9,1.663742,1.316419,0.649784,0.611799,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.172167,2.025127,0.315093,0.245049,00:35
1,2.081193,1.855959,0.389158,0.324772,00:35
2,1.956747,1.831199,0.400102,0.342507,00:35
3,1.90618,1.660453,0.486892,0.430269,00:35
4,1.865507,1.647197,0.505218,0.455562,00:35
5,1.804027,1.495214,0.578519,0.532386,00:35
6,1.763422,1.535518,0.554594,0.508779,00:35
7,1.683836,1.359977,0.638076,0.597669,00:35
8,1.670078,1.381175,0.599389,0.560187,00:35
9,1.605983,1.382951,0.622041,0.580478,00:35


# AvgAttnPooling2d SansFFN Act

In [None]:
class AttentionPool2d(nn.Module):
    "Single-query attention pooling over the spatial positions of a feature map"
    def __init__(self,
        ni:int, # Number of input channels; also the width of every projection
        bias:bool=True, # Whether the query and key/value projections use a bias
        norm:Callable[[int], nn.Module]=nn.LayerNorm # Factory called as `norm(ni)` for the input norm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # Batch norms are applied with channels on dim 1; any other norm is applied to the last dim.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "Apply the batch norm to channels-first `x` and return it as (B, N, C)."
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4D input, so it must run before the spatial dims are flattened.
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten channels-first `x` to (B, N, C), then apply `self.norm` on the trailing dim."
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        "Pool `x` into one `C`-dim vector per sample, using `cls_q` as the single query."
        x = self.norm_forward(x)
        B, N, C = x.shape

        # The shared query is broadcast across the batch before projection.
        q = self.q(cls_q.expand(B, -1, -1))
        # `vk` emits 2*C features per position: the first C are used as keys, the last C as values.
        k, v = self.vk(x).reshape(B, N, 2, C).permute(2, 0, 1, 3).chunk(2, 0)

        # Single-head attention over the N positions; the logits are not scaled by 1/sqrt(C).
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnPooling2d(nn.Module):
    "Sum of average pooling and attention pooling, normalised and passed through an activation (no FFN)"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3, # Accepted for signature compatibility; not used by this variant
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        # Learned query, re-initialised below with a truncated normal.
        self.cls_q = nn.Parameter(torch.zeros(1, ni))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm = norm(ni)
        self.act = act_cls()
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        "Return `act(norm(avg_pool(x) + attn_pool(x)))`."
        averaged = self.pool(x).flatten(1)
        attended = self.attn(x, self.cls_q)
        return self.act(self.norm(averaged + attended))

    @torch.no_grad()
    def _init_weights(self, m):
        "Truncated-normal weights and zero bias for `nn.Linear`; every other module is left untouched."
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [None]:
def AvgAttnPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Return `[AvgAttnPooling2d, nn.Linear(ni, n_out)]` with the linear layer initialized like the pooling layers"
    pooling = AvgAttnPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    classifier = nn.Linear(ni, n_out)
    # Reuse the pooling module's init so the classifier gets the same weight scheme
    with torch.no_grad():
        pooling._init_weights(classifier)
    return [pooling, classifier]

In [None]:
# xeca_resnext50s with AvgAttnPoolHead as its custom head, lr=2e-3
train(
    'XResNeXt50S AvgAttnPool SansFFN Act',
    partial(xeca_resnext50s, custom_head=AvgAttnPoolHead),
    lr=2e-3,
)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.217198,2.103383,0.2479,0.179801,00:35
1,2.086106,1.912519,0.390939,0.326242,00:35
2,2.01722,1.886144,0.361415,0.299658,00:35
3,1.898516,1.641482,0.515144,0.461954,00:35
4,1.881334,1.703461,0.456096,0.405346,00:35
5,1.779724,1.461578,0.576737,0.531642,00:35
6,1.7332,1.731334,0.447951,0.418224,00:35
7,1.654424,1.408283,0.605752,0.564248,00:35
8,1.644352,1.406972,0.602952,0.567805,00:34
9,1.557352,1.403816,0.600662,0.565524,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.189426,2.086656,0.256554,0.180384,00:35
1,2.091202,1.883134,0.369305,0.301477,00:35
2,2.016049,1.819536,0.41181,0.353154,00:35
3,1.939035,1.659163,0.496564,0.440824,00:35
4,1.840533,1.657263,0.466022,0.413258,00:35
5,1.79634,1.452765,0.59379,0.548624,00:35
6,1.723863,1.555904,0.554594,0.508021,00:35
7,1.699689,1.480451,0.570883,0.528178,00:35
8,1.618445,1.334449,0.646729,0.607399,00:35
9,1.610651,1.494324,0.552049,0.514274,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.169126,2.054002,0.275643,0.207698,00:35
1,2.080884,1.880108,0.386103,0.323197,00:35
2,1.982408,1.91783,0.34818,0.289814,00:35
3,1.949485,1.702753,0.47213,0.41436,00:34
4,1.882707,1.694783,0.476712,0.42577,00:35
5,1.795479,1.480819,0.583864,0.538133,00:35
6,1.737259,1.574161,0.51438,0.465883,00:35
7,1.661399,1.337026,0.640112,0.600854,00:35
8,1.648251,1.370919,0.610079,0.571646,00:35
9,1.54215,1.244661,0.676763,0.641425,00:35


# AvgAttnConcatPooling2d SansFFN Act

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        """Create the norm layer and the query, key/value and output projections.

        ni: feature width; every projection takes `ni` input features.
        bias: passed to the `q` and `vk` linear layers (`proj` keeps the `nn.Linear` default).
        norm: called as `norm(ni)` to build the normalization layer.
        """
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        # One linear layer produces keys and values together, hence twice the width
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # Batch norms are applied before the transpose, every other norm after it
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        """Apply a batch norm to `x` and return it flattened with channels last.

        `BatchNorm2d` only accepts 4D input, so it is applied to the feature map
        before flattening. `BatchNorm1d` accepts the flattened 3D tensor, so it is
        applied after `flatten(2)`.
        """
        if isinstance(self.norm, nn.BatchNorm2d):
            # Normalizing the flattened tensor would raise "expected 4D input"
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` from dim 2 onward, swap dims 1 and 2, then apply `self.norm`"
        return self.norm(x.flatten(2).transpose(1,2))
    
    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool `x` into one vector per batch item by attending from `cls_q`.

        x: feature map; `self.norm_forward` turns it into a (B, N, C) tensor.
        cls_q: query tensor, expanded across the batch before the `q` projection.
        Returns the `proj` output of the attention-weighted values, computed from a (B, C) tensor.
        """
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))
        # Split the joint projection: index 0 of the new dim is keys, index 1 is values
        k, v = self.vk(x).reshape(B, N, 2, C).permute(2, 0, 1, 3).chunk(2, 0)

        # Note: the dot products are not scaled before the softmax
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnConcatPooling2d(nn.Module):
    "Concatenate separately normalized average-pooled and attention-pooled features, then apply an activation"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        # `ffn_expand` is accepted but never read by this variant
        super().__init__()
        # Learned query handed to the attention pool on every forward pass
        self.cls_q = nn.Parameter(torch.zeros(1, ni))
        self.attn = AttentionPool2d(ni, bias=attn_bias, norm=norm)
        self.norm1 = norm(ni)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm2 = norm(ni)
        self.act = act_cls()
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        # Re-initialize every linear layer created above
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        # Output width is 2*ni: average-pool branch first, attention branch second
        avg_feats = self.norm2(self.pool(x).flatten(1))
        attn_feats = self.norm1(self.attn(x, self.cls_q))
        return self.act(torch.cat([avg_feats, attn_feats], dim=1))

    @torch.no_grad()
    def _init_weights(self, m):
        "Truncated-normal weights (std 0.02) and zero bias for `nn.Linear`; other modules are left alone"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [None]:
def AvgAttnConcatPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Return `[AvgAttnConcatPooling2d, nn.Linear(2*ni, n_out)]` with the linear layer initialized like the pooling layers"
    pooling = AvgAttnConcatPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    # The classifier takes 2*ni input features
    classifier = nn.Linear(2*ni, n_out)
    with torch.no_grad():
        pooling._init_weights(classifier)
    return [pooling, classifier]

In [None]:
# xeca_resnext50s with AvgAttnConcatPoolHead as its custom head, lr=2e-3
train(
    'XResNeXt50S AvgAttnConcatPool SansFFN Act',
    partial(xeca_resnext50s, custom_head=AvgAttnConcatPoolHead),
    lr=2e-3,
)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.297112,2.215315,0.235174,0.157868,00:41
1,2.182778,2.145555,0.2535,0.176476,00:35
2,2.136963,1.967891,0.322728,0.251663,00:35
3,2.032156,1.955669,0.335454,0.265398,00:36
4,2.018785,1.779679,0.437516,0.373289,00:35
5,1.921312,1.79556,0.410792,0.349457,00:35
6,1.866847,1.79488,0.417664,0.35701,00:35
7,1.820113,1.818485,0.419445,0.361034,00:35
8,1.788138,1.500756,0.568592,0.521201,00:35
9,1.697274,1.604156,0.519471,0.473759,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.264998,2.120294,0.255281,0.172644,00:35
1,2.199141,2.172093,0.23492,0.149799,00:35
2,2.163475,2.021135,0.30364,0.225219,00:35
3,2.065572,2.250071,0.23721,0.171376,00:35
4,1.969288,1.806692,0.395011,0.332336,00:35
5,1.904979,1.96812,0.352252,0.289287,00:35
6,1.845357,1.657282,0.481038,0.42756,00:35
7,1.793278,1.661786,0.489947,0.43379,00:35
8,1.736913,1.51531,0.565029,0.518258,00:36
9,1.72295,1.557994,0.53296,0.488482,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.248329,2.152552,0.230848,0.147614,00:35
1,2.180704,2.027618,0.29804,0.221314,00:35
2,2.073351,2.043039,0.285823,0.213865,00:35
3,2.005577,1.801045,0.406974,0.341875,00:35
4,1.96151,1.811962,0.411555,0.348848,00:36
5,1.881781,1.644522,0.489692,0.436386,00:35
6,1.832667,1.616667,0.499109,0.447148,00:35
7,1.766659,1.703825,0.44846,0.397705,00:35
8,1.74663,1.516992,0.553831,0.505286,00:35
9,1.663023,1.631146,0.51438,0.464984,00:35


# AvgAttnConcatPooling2d Sandwich Act

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        """Create the norm layer and the query, key/value and output projections.

        ni: feature width; every projection takes `ni` input features.
        bias: passed to the `q` and `vk` linear layers (`proj` keeps the `nn.Linear` default).
        norm: called as `norm(ni)` to build the normalization layer.
        """
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        # One linear layer produces keys and values together, hence twice the width
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # Batch norms are applied before the transpose, every other norm after it
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        """Apply a batch norm to `x` and return it flattened with channels last.

        `BatchNorm2d` only accepts 4D input, so it is applied to the feature map
        before flattening. `BatchNorm1d` accepts the flattened 3D tensor, so it is
        applied after `flatten(2)`.
        """
        if isinstance(self.norm, nn.BatchNorm2d):
            # Normalizing the flattened tensor would raise "expected 4D input"
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` from dim 2 onward, swap dims 1 and 2, then apply `self.norm`"
        return self.norm(x.flatten(2).transpose(1,2))
    
    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool `x` into one vector per batch item by attending from `cls_q`.

        x: feature map; `self.norm_forward` turns it into a (B, N, C) tensor.
        cls_q: query tensor, expanded across the batch before the `q` projection.
        Returns the `proj` output of the attention-weighted values, computed from a (B, C) tensor.
        """
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))
        # Split the joint projection: index 0 of the new dim is keys, index 1 is values
        k, v = self.vk(x).reshape(B, N, 2, C).permute(2, 0, 1, 3).chunk(2, 0)

        # Note: the dot products are not scaled before the softmax
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnConcatPooling2d(nn.Module):
    "Concatenate normalized average-pooled features with an attention-pooled token refined by a residual FFN"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        # Hidden width of the feed-forward block
        hidden = int(ni*ffn_expand)
        # Learned query; also the residual base for the attention output in `forward`
        self.cls_q = nn.Parameter(torch.zeros(1, ni))
        self.attn = AttentionPool2d(ni, bias=attn_bias, norm=norm)
        self.norm1 = norm(ni)
        self.norm2 = norm(ni)
        # "Sandwich" FFN: a norm sits between the activation and the second linear layer
        self.ffn = nn.Sequential(
            nn.Linear(ni, hidden),
            act_cls(),
            norm(hidden),
            nn.Linear(hidden, ni)
        )
        self.norm3 = norm(ni)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm4 = norm(ni)
        self.act = act_cls()
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        # Re-initialize every linear layer created above
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        # Attention branch: residual on the query, then a pre-norm residual FFN
        token = self.cls_q + self.norm1(self.attn(x, self.cls_q))
        token = token + self.ffn(self.norm2(token))
        # Output width is 2*ni: average-pool branch first, attention branch second
        avg_feats = self.norm4(self.pool(x).flatten(1))
        return self.act(torch.cat([avg_feats, self.norm3(token)], dim=1))

    @torch.no_grad()
    def _init_weights(self, m):
        "Truncated-normal weights (std 0.02) and zero bias for `nn.Linear`; other modules are left alone"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [None]:
def AvgAttnConcatPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Return `[AvgAttnConcatPooling2d, nn.Linear(2*ni, n_out)]` with the linear layer initialized like the pooling layers"
    pooling = AvgAttnConcatPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    # The classifier takes 2*ni input features
    classifier = nn.Linear(2*ni, n_out)
    with torch.no_grad():
        pooling._init_weights(classifier)
    return [pooling, classifier]

In [None]:
# xeca_resnext50s with AvgAttnConcatPoolHead as its custom head, lr=4e-3
train(
    'XResNeXt50S AvgAttnConcatPool Sandwich Act',
    partial(xeca_resnext50s, custom_head=AvgAttnConcatPoolHead),
    lr=4e-3,
)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.24707,2.312194,0.200305,0.122885,00:35
1,2.286287,2.243082,0.175617,0.085144,00:34
2,2.191408,2.126094,0.232375,0.161776,00:34
3,2.091385,1.962409,0.341563,0.269655,00:35
4,2.026589,1.94032,0.341563,0.274004,00:35
5,1.987536,1.887734,0.36167,0.313164,00:35
6,1.954597,1.934154,0.355561,0.291447,00:35
7,1.858371,1.638948,0.483838,0.42766,00:35
8,1.809539,1.726303,0.451514,0.401177,00:34
9,1.723825,1.440339,0.589718,0.544999,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.34271,2.217415,0.204123,0.114785,00:35
1,2.30108,2.200681,0.216849,0.127514,00:34
2,2.23748,2.198413,0.215067,0.133729,00:35
3,2.166969,2.117022,0.240519,0.170383,00:35
4,2.075769,1.983752,0.312293,0.240211,00:35
5,1.995202,2.148214,0.256808,0.194401,00:35
6,1.913547,1.779634,0.413082,0.353252,00:35
7,1.840586,1.603961,0.505981,0.454219,00:35
8,1.763736,1.630456,0.496564,0.447018,00:35
9,1.732377,1.493826,0.547468,0.502179,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.281902,2.40391,0.221176,0.150351,00:35
1,2.196093,2.236266,0.254009,0.181431,00:35
2,2.0767,1.93108,0.35098,0.279181,00:35
3,2.048316,2.007191,0.294986,0.231796,00:34
4,1.958012,1.690471,0.478493,0.4212,00:35
5,1.851131,1.782991,0.422245,0.37132,00:35
6,1.786196,1.508677,0.554849,0.508484,00:34
7,1.696515,1.599271,0.510308,0.463812,00:35
8,1.646349,1.332712,0.644693,0.606096,00:35
9,1.55048,1.407001,0.599898,0.561184,00:35


# XResNeXt50S 4e-3

In [None]:
# Baseline: xeca_resnext50s with its default head, lr=4e-3
# NOTE(review): the run name does not encode the learning rate, and another cell in this
# notebook reuses 'XResNeXt50S' at a different lr; consider disambiguating if runs are tracked by name
train(
    'XResNeXt50S',
    xeca_resnext50s,
    lr=4e-3,
)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.199578,2.177327,0.252736,0.183022,00:41
1,2.036237,1.970947,0.370069,0.308797,00:34
2,1.968839,1.822005,0.412319,0.355031,00:33
3,1.850082,1.861875,0.411555,0.366236,00:34
4,1.795483,1.606907,0.52787,0.477833,00:34
5,1.690547,1.397515,0.610842,0.568029,00:34
6,1.62262,1.410447,0.606516,0.569714,00:35
7,1.535878,1.237257,0.688725,0.655927,00:35
8,1.528444,1.274081,0.674981,0.642484,00:33
9,1.439838,1.099151,0.764062,0.737665,00:34


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.208771,2.89967,0.168491,0.107973,00:34
1,2.07792,1.890151,0.37185,0.303591,00:33
2,1.989523,1.858901,0.392466,0.332761,00:33
3,1.892927,1.701194,0.476966,0.421419,00:34
4,1.77662,1.597322,0.518198,0.466353,00:34
5,1.686146,1.379799,0.626877,0.587186,00:34
6,1.604083,1.460429,0.587681,0.546244,00:34
7,1.57895,1.23662,0.695851,0.662233,00:34
8,1.50339,1.322603,0.650038,0.615629,00:34
9,1.493593,1.145677,0.731229,0.702591,00:34


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.206701,2.147917,0.241792,0.167819,00:33
1,2.078779,1.883552,0.386103,0.321354,00:33
2,1.959669,1.940098,0.363197,0.305265,00:34
3,1.89638,1.65378,0.5014,0.44565,00:33
4,1.828534,1.619388,0.518453,0.470818,00:34
5,1.706048,1.468858,0.590481,0.548437,00:34
6,1.645401,1.611683,0.50649,0.465473,00:33
7,1.544696,1.219901,0.704505,0.671802,00:34
8,1.527019,1.275575,0.670654,0.636468,00:34
9,1.439009,1.135573,0.741665,0.714127,00:34


# XResNeXt50S 2e-3

In [None]:
# Baseline: xeca_resnext50s with its default head, lr=2e-3
# NOTE(review): the run name does not encode the learning rate, and another cell in this
# notebook reuses 'XResNeXt50S' at a different lr; consider disambiguating if runs are tracked by name
train(
    'XResNeXt50S',
    xeca_resnext50s,
    lr=2e-3,
)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.214926,2.178648,0.239756,0.15838,00:34
1,2.086955,1.972897,0.354543,0.287404,00:33
2,1.993013,1.846488,0.39603,0.332417,00:33
3,1.890417,1.733408,0.463477,0.408544,00:34
4,1.854775,1.754693,0.45406,0.40015,00:34
5,1.770636,1.537141,0.552558,0.503488,00:33
6,1.711036,1.642688,0.50649,0.464346,00:33
7,1.630091,1.354583,0.632985,0.594055,00:34
8,1.607546,1.413861,0.606007,0.568416,00:34
9,1.523908,1.231358,0.697124,0.664636,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.211693,2.12376,0.253754,0.172806,00:34
1,2.101108,1.940102,0.348689,0.278938,00:33
2,2.00939,1.833033,0.394502,0.332723,00:33
3,1.921067,1.751356,0.455332,0.396659,00:34
4,1.841918,1.805307,0.445915,0.389347,00:34
5,1.785612,1.556543,0.550522,0.503999,00:34
6,1.711014,1.667698,0.501145,0.456627,00:34
7,1.661788,1.380684,0.626114,0.585223,00:34
8,1.58841,1.385006,0.62026,0.580533,00:34
9,1.571578,1.328068,0.635022,0.597947,00:34


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.218417,2.177433,0.229575,0.152096,00:33
1,2.095515,1.953498,0.336218,0.266409,00:33
2,1.976661,1.930134,0.368033,0.305164,00:34
3,1.922175,1.717,0.47493,0.41924,00:33
4,1.871294,1.67382,0.494019,0.439326,00:34
5,1.781707,1.555937,0.546704,0.498758,00:34
6,1.722283,1.614509,0.505218,0.459826,00:34
7,1.632742,1.328306,0.65182,0.613657,00:34
8,1.606763,1.337647,0.642148,0.604711,00:34
9,1.523794,1.257299,0.683635,0.650562,00:34
