# Logging

In [None]:
# Weights & Biases setup. WANDB_SILENT is set via %env before `wandb` is imported;
# `train` below opens wandb runs (wandb.init) when called with log=True.
%env WANDB_SILENT=true
import wandb
wandb.login()

env: WANDB_SILENT=true


True

# Imports

In [None]:
from __future__ import annotations
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback
from fastxtend.vision.all import *
from tfrecord.reader import example_loader

In [None]:
import warnings
# Suppress every UserWarning for the rest of the kernel session.
# NOTE(review): this also hides potentially useful UserWarnings from fastai/PyTorch.
warnings.filterwarnings("ignore", category=UserWarning)

# AttentionPooling XResNet

In [None]:
def xeca_resnext50s(n_out=1000,  **kwargs):
    """Three-stage ECA-XResNeXt ('50s' variant) built with `XResNet`.

    Uses `ECAResNeXtBlock`, passes 4 as XResNet's second positional argument
    (presumably the block expansion), stage depths [3, 4, 9] and block widths
    [64, 128, 256]. Extra keyword arguments are forwarded to `XResNet`.
    """
    stage_depths = [3, 4, 9]
    stage_widths = [64, 128, 256]
    return XResNet(ECAResNeXtBlock, 4, stage_depths, block_szs=stage_widths, n_out=n_out, **kwargs)

# Imagenette/Woof

In [None]:
# Per-channel (mean, std) normalization statistics for each dataset.
imagewoof_stats  = ([0.496,0.461,0.399],[0.257,0.249,0.258])
imagenette_stats = ([0.465,0.458,0.429],[0.285,0.28,0.301])

def get_dls(size, woof, bs, sh=0., augs=None, workers=None, stats=True):
    """Build Imagenette or ImageWoof `DataLoaders` with crops of `size`.

    Args:
        size: `RandomResizedCrop` size; sizes <= 224 download the 320px dataset variant.
        woof: use ImageWoof (and its stats) instead of Imagenette.
        bs: batch size.
        sh: if non-zero, append `RandomErasing(p=0.3, max_count=3, sh=sh)`.
        augs: extra batch transforms appended to the batch transform list.
        workers: number of dataloader workers; defaults to min(8, num_cpus()).
        stats: prepend a `Normalize` with the dataset's stats.
    """
    use_small = size <= 224
    if woof:
        dataset_url = URLs.IMAGEWOOF_320 if use_small else URLs.IMAGEWOOF
        dataset_stats = imagewoof_stats
    else:
        dataset_url = URLs.IMAGENETTE_320 if use_small else URLs.IMAGENETTE
        dataset_stats = imagenette_stats
    source = untar_data(dataset_url)
    num_workers = min(8, num_cpus()) if workers is None else workers

    tfms = [Normalize.from_stats(*dataset_stats)] if stats else []
    if augs:
        tfms.extend(augs)
    if sh:
        tfms.append(RandomErasing(p=0.3, max_count=3, sh=sh))

    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=GrandparentSplitter(valid_name='val'),
                       get_items=get_image_files, get_y=parent_label,
                       item_tfms=[RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)],
                       batch_tfms=tfms)
    return dblock.dataloaders(source, bs=bs, num_workers=num_workers)

# Training Code

In [None]:
def train(run_name, model, lr, epochs=80, bs=64, size=224, loss_func=LabelSmoothingCrossEntropyFlat(), 
          batch_tfms=aug_transforms(max_zoom=1, max_rotate=20, xtra_tfms=[Hue(), Saturation()]), 
          woof=False, seeds=[42,314,1618], valid_sizes=[256,384], log=True):
    """Train `model` once per seed and re-validate each trained model at extra image sizes.

    For each seed in `seeds`:
    1. Build dataloaders with `get_dls`.
    2. Train with `fit_flat_cos(epochs, lr)` in fp16.
    3. Save the weights (without optimizer state) as f'{run_name} lr={lr} seed={seed}'.
    4. Validate at every size in `valid_sizes` with batch size 48.

    When `log` is True, a single W&B run covers all seeds. Each seed's final recorder
    metrics and extra-size validation results are collected, and their mean/std are
    written to the run summary. When `log` is False, the extra-size accuracy and
    Matthews scores are printed instead.

    Args:
        run_name: prefix for the W&B run name and the saved weight files.
        model: callable building the network; called with `n_out`, `act_cls`,
            `stem_pool` and `block_pool`.
        lr: learning rate for `fit_flat_cos`.
        epochs, bs, size: number of epochs, batch size and training image size.
        loss_func, batch_tfms: loss function and extra batch transforms. These
            defaults are evaluated once, when `train` is defined, and shared by calls.
        woof: train on ImageWoof instead of Imagenette.
        seeds: one full training run per seed.
        valid_sizes: a size, or list of sizes, for the extra validation passes.
        log: log to Weights & Biases.

    NOTE(review): the recorded outputs in this notebook show 10 epochs per seed,
    while `epochs` defaults to 80. Confirm which value is intended.
    """
    if not is_listy(valid_sizes):
        valid_sizes = [valid_sizes]

    run, run_results, broke = None, {}, False
    if log:
        run = wandb.init(project="AttentionPooling", name=f'{run_name} lr={lr}', 
                         group= f'ImageWoof {epochs}E' if woof else f'Imagenette {epochs}E')
    try:
        for seed in seeds:
            # Reset per seed so cleanup never references an undefined learner (NameError)
            # or one that was already freed for the previous seed.
            learn, dls = None, None
            try:
                with less_random(seed):
                    dls = get_dls(size, woof, bs, augs=batch_tfms)

                cbs = [WandbCallback(log=None, log_preds=False, log_model=False)] if log else []
                with less_random(seed):
                    learn = Learner(dls, model(n_out=dls.c, act_cls=nn.Mish, stem_pool=MaxBlurPool, block_pool=BlurPool), 
                                    loss_func=loss_func, opt_func=ranger,
                                    metrics=[accuracy, MatthewsCorrCoef()], 
                                    cbs=[CutMixUpAugment(), TerminateOnTrainNaN()]+cbs).to_fp16()

                with less_random(seed):
                    learn.fit_flat_cos(epochs, lr) 
                    learn.save(f'{run_name} lr={lr} seed={seed}', with_opt=False)

                if log:
                    # Final-epoch recorder values, skipping the epoch index and timing columns.
                    for n,v in zip(learn.recorder.metric_names, learn.recorder.log):
                        if n not in ['epoch', 'time']:
                            run_results.setdefault(n, []).append(v)

                if log: learn.remove_cb(WandbCallback)
                for vs in valid_sizes:
                    dls = None  # drop the previous DataLoaders reference before building the next one
                    learn.dls = get_dls(vs, woof, 48)
                    loss, acc, matthews = learn.validate()
                    if log:
                        for n,v in zip([f'loss_{vs}', f'accuracy_{vs}', f'matthews_{vs}'], [loss, acc, matthews]):
                            run_results.setdefault(n, []).append(v)
                    else:
                        print(f'accuracy_{vs}: {acc}, matthews_{vs}: {matthews}')
            except CancelFitException:
                # NOTE(review): fastai's Learner.fit handles CancelFitException itself, so this
                # may never fire from fit_flat_cos. Confirm how TerminateOnTrainNaN reports a NaN stop.
                print('Terminate on NAN')
                broke = True
                raise
            finally:
                if learn is not None:
                    free_gpu_memory(learn, dls)

    finally:
        if log:
            # run_results starts empty, so an early failure leaves nothing to summarise
            # instead of raising here and masking the original exception.
            if not broke:
                for n, vals in run_results.items():
                    run.summary[f'{n}_mean'] = np.mean(vals)
                    run.summary[f'{n}_std'] = np.std(vals)
            run.finish()

# XResNeXt50

In [None]:
# Baseline: `xeca_resnext50` (not defined in this notebook) with `train` defaults.
train(run_name='XResNeXt50', model=xeca_resnext50, lr=8e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.931443,1.859167,0.461656,0.412816,00:43
1,1.782335,1.546854,0.561274,0.520042,00:36
2,1.641653,1.468011,0.590064,0.556021,00:36
3,1.510957,1.224061,0.703694,0.672715,00:36
4,1.456795,1.252888,0.69121,0.663444,00:35
5,1.383507,0.988383,0.810955,0.790108,00:36
6,1.313917,1.050969,0.78242,0.761917,00:36
7,1.284544,0.974421,0.81121,0.79293,00:35
8,1.230295,1.314478,0.642293,0.622429,00:36
9,1.211341,0.842395,0.874904,0.861195,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.050091,2.031096,0.408917,0.347432,00:36
1,1.814444,1.498833,0.612739,0.574571,00:36
2,1.613603,1.345989,0.658089,0.623293,00:36
3,1.501361,1.150921,0.74242,0.714327,00:36
4,1.396268,1.203587,0.709554,0.683366,00:36
5,1.38529,1.029264,0.797962,0.777399,00:36
6,1.319885,1.177109,0.726115,0.70441,00:35
7,1.265612,0.966612,0.822166,0.803965,00:37
8,1.275217,0.996361,0.80051,0.781405,00:36
9,1.219295,0.899976,0.852994,0.837623,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.945756,2.34968,0.390064,0.337925,00:36
1,1.734989,1.440277,0.605605,0.565456,00:36
2,1.58738,1.694418,0.54293,0.505344,00:36
3,1.486477,1.12333,0.750064,0.723527,00:36
4,1.408277,1.186446,0.727389,0.701239,00:36
5,1.352255,0.983016,0.814013,0.793971,00:36
6,1.323621,0.999514,0.804586,0.784543,00:36
7,1.260409,0.915581,0.842548,0.825673,00:36
8,1.248298,0.973432,0.810955,0.791978,00:36
9,1.237463,0.8717,0.857834,0.842921,00:36


# XResNeXt50S

In [None]:
# Three-stage variant defined above, default head.
train(run_name='XResNeXt50S', model=xeca_resnext50s, lr=8e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.873333,1.787193,0.444841,0.400915,00:41
1,1.684915,1.350767,0.664713,0.629839,00:35
2,1.562775,1.365247,0.63949,0.6058,00:36
3,1.458607,1.144725,0.752866,0.726847,00:36
4,1.405803,1.120221,0.75414,0.729852,00:35
5,1.352149,0.980994,0.812229,0.791986,00:36
6,1.281261,1.073795,0.770191,0.748318,00:36
7,1.257544,0.951602,0.819363,0.80209,00:36
8,1.204146,1.047317,0.774013,0.753901,00:37
9,1.176287,0.844125,0.867261,0.852851,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.927342,1.708915,0.502675,0.452693,00:36
1,1.69675,1.409351,0.634904,0.601711,00:36
2,1.533466,1.383173,0.629045,0.596742,00:36
3,1.439683,1.073894,0.776306,0.751955,00:36
4,1.351007,1.12171,0.755159,0.73064,00:36
5,1.346685,1.034178,0.792866,0.773228,00:36
6,1.282375,1.412918,0.624713,0.613662,00:36
7,1.233703,0.90455,0.852484,0.836924,00:37
8,1.247338,1.02438,0.793376,0.77428,00:36
9,1.191223,0.876627,0.855541,0.840332,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.933242,1.635115,0.520255,0.469895,00:37
1,1.697675,1.396199,0.631847,0.593151,00:36
2,1.551829,1.62743,0.578853,0.542854,00:37
3,1.462276,1.079256,0.77656,0.752243,00:37
4,1.379537,1.063922,0.77656,0.754018,00:36
5,1.331353,0.963519,0.822675,0.803392,00:36
6,1.300989,0.999546,0.803312,0.783663,00:36
7,1.237689,0.915753,0.841274,0.824237,00:36
8,1.222376,0.925327,0.829809,0.813141,00:37
9,1.21632,0.848152,0.869554,0.855423,00:37


# Learned Aggregation

In [None]:
class AttentionPool2d(nn.Module):
    """Single-head attention pooling used by Learned Aggregation.

    Normalizes the input, flattens it into a (B, N, C) sequence, and lets the
    learned query `cls_q` attend over all N positions. Returns `proj` of the
    attended vector, shape (B, C).

    Args:
        ni: channel count C of the input (also the attention and output width).
        bias: whether the query and key/value projections have a bias.
        norm: factory called as `norm(ni)` for the input normalization.
            BatchNorm2d is applied to the 4D input, BatchNorm1d to the (B, C, H*W)
            flattening, and any other norm to the channel dim of the (B, H*W, C) sequence.
    """
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm2d rejects 3D input, so it must run before flattening.
        if isinstance(self.norm, nn.BatchNorm2d):
            self.norm_forward = self.bn2d_norm_flat
        elif isinstance(self.norm, nn.BatchNorm1d):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn2d_norm_flat(self, x:Tensor):
        "BatchNorm2d on the (B, C, H, W) input, then flatten to (B, H*W, C)"
        return self.norm(x).flatten(2).transpose(1,2)

    def bn_norm_flat(self, x:Tensor):
        "Flatten to (B, C, H*W), BatchNorm1d over channels, then transpose to (B, H*W, C)"
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten and transpose to (B, H*W, C), then normalize the last (channel) dim"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        "Pool feature map `x` with query `cls_q` (expected to broadcast to (B, 1, C), e.g. (1, C))"
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))              # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).unbind(2)  # each (B, N, C)

        # NOTE(review): logits are not scaled by C**-0.5 as in standard scaled
        # dot-product attention; left unscaled to preserve this experiment's behavior.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)  # (B, 1, N)

        return self.proj((attn @ v).reshape(B, C))
 
    
class LearnedAggregation(nn.Module):
    "Learned Aggregation from https://arxiv.org/abs/2112.13692"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        self.gamma_1 = nn.Parameter(1e-4 * torch.ones(ni))
        self.gamma_2 = nn.Parameter(1e-4 * torch.ones(ni))
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.norm = norm(ni)
        self.ffn = nn.Sequential(
            nn.Linear(ni, int(ni*ffn_expand)),
            act_cls(),
            nn.Linear(int(ni*ffn_expand), ni)
        )
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        x = self.cls_q + self.gamma_1 * self.attn(x, self.cls_q)
        return x + self.gamma_2 * self.ffn(self.norm(x))

    @torch.no_grad()
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

In [None]:
def LearnedAggregationHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    """Classifier head: `LearnedAggregation` pooling, `norm(ni)`, then `nn.Linear(ni, n_out)`.

    Returns the three modules as a list. The norm and linear layers are passed
    through the pooling block's `_init_weights`.
    """
    pool = LearnedAggregation(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    final_norm = norm(ni)
    classifier = nn.Linear(ni, n_out)
    with torch.no_grad():
        for layer in (final_norm, classifier):
            pool._init_weights(layer)
    return [pool, final_norm, classifier]

In [None]:
train(run_name='XResNeXt50S LearnedAggregation',
      model=partial(xeca_resnext50s, custom_head=LearnedAggregationHead),
      lr=2e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.011165,1.818747,0.407134,0.373827,00:42
1,1.831233,1.609386,0.528662,0.484669,00:36
2,1.675683,1.338111,0.645096,0.608979,00:36
3,1.551039,1.313098,0.662166,0.629161,00:36
4,1.501497,1.12369,0.747771,0.72164,00:36
5,1.454054,1.281678,0.661401,0.6335,00:36
6,1.383226,1.064597,0.779618,0.756707,00:37
7,1.341928,1.171533,0.72,0.696274,00:37
8,1.297376,1.009955,0.789045,0.768117,00:37
9,1.269322,1.082096,0.748535,0.726388,00:37


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.007233,2.02511,0.338599,0.280297,00:37
1,1.814006,1.60603,0.523057,0.481228,00:37
2,1.647912,1.278171,0.694777,0.664264,00:37
3,1.550788,1.284282,0.676178,0.643137,00:37
4,1.444968,1.128719,0.750828,0.724279,00:37
5,1.433881,1.169968,0.727389,0.70002,00:37
6,1.376617,1.053587,0.77121,0.748888,00:37
7,1.335432,1.115497,0.741147,0.716077,00:37
8,1.352441,0.988169,0.804586,0.784155,00:37
9,1.294061,1.038528,0.784968,0.76428,00:37


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.996875,1.760824,0.448662,0.400233,00:37
1,1.820866,1.855295,0.395159,0.366724,00:37
2,1.657393,1.30942,0.673376,0.640015,00:37
3,1.579724,1.456973,0.596178,0.561727,00:37
4,1.506467,1.203226,0.703185,0.677236,00:37
5,1.443111,1.210192,0.709554,0.682114,00:37
6,1.409171,1.113854,0.750064,0.724626,00:37
7,1.33428,1.129116,0.745987,0.723107,00:37
8,1.320784,0.98078,0.805096,0.784301,00:37
9,1.324352,1.013504,0.793885,0.772275,00:37


# Learned Aggregation Sandwich

In [None]:
class AttentionPool2d(nn.Module):
    """Single-head attention pooling used by Learned Aggregation.

    Normalizes the input, flattens it into a (B, N, C) sequence, and lets the
    learned query `cls_q` attend over all N positions. Returns `proj` of the
    attended vector, shape (B, C).

    Args:
        ni: channel count C of the input (also the attention and output width).
        bias: whether the query and key/value projections have a bias.
        norm: factory called as `norm(ni)` for the input normalization.
            BatchNorm2d is applied to the 4D input, BatchNorm1d to the (B, C, H*W)
            flattening, and any other norm to the channel dim of the (B, H*W, C) sequence.
    """
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm2d rejects 3D input, so it must run before flattening.
        if isinstance(self.norm, nn.BatchNorm2d):
            self.norm_forward = self.bn2d_norm_flat
        elif isinstance(self.norm, nn.BatchNorm1d):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn2d_norm_flat(self, x:Tensor):
        "BatchNorm2d on the (B, C, H, W) input, then flatten to (B, H*W, C)"
        return self.norm(x).flatten(2).transpose(1,2)

    def bn_norm_flat(self, x:Tensor):
        "Flatten to (B, C, H*W), BatchNorm1d over channels, then transpose to (B, H*W, C)"
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten and transpose to (B, H*W, C), then normalize the last (channel) dim"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        "Pool feature map `x` with query `cls_q` (expected to broadcast to (B, 1, C), e.g. (1, C))"
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))              # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).unbind(2)  # each (B, N, C)

        # NOTE(review): logits are not scaled by C**-0.5 as in standard scaled
        # dot-product attention; left unscaled to preserve this experiment's behavior.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)  # (B, 1, N)

        return self.proj((attn @ v).reshape(B, C))
 
    
class LearnedAggregation(nn.Module):
    "Learned Aggregation from https://arxiv.org/abs/2112.13692"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        self.gamma_1 = nn.Parameter(1e-4 * torch.ones(ni))
        self.gamma_2 = nn.Parameter(1e-4 * torch.ones(ni))
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.norm1 = norm(ni)
        self.norm2 = norm(ni)
        self.ffn = nn.Sequential(
            nn.Linear(ni, int(ni*ffn_expand)),
            act_cls(),
            norm(int(ni*ffn_expand)),
            nn.Linear(int(ni*ffn_expand), ni)
        )
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        x = self.cls_q + self.gamma_1 * self.norm1(self.attn(x, self.cls_q))
        return x + self.gamma_2 * self.ffn(self.norm2(x))

    @torch.no_grad()
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

In [None]:
def LearnedAggregationHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    """Classifier head: `LearnedAggregation` pooling, `norm(ni)`, then `nn.Linear(ni, n_out)`.

    Returns the three modules as a list. The norm and linear layers are passed
    through the pooling block's `_init_weights`.
    """
    pool = LearnedAggregation(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    final_norm = norm(ni)
    classifier = nn.Linear(ni, n_out)
    with torch.no_grad():
        for layer in (final_norm, classifier):
            pool._init_weights(layer)
    return [pool, final_norm, classifier]

In [None]:
train(run_name='XResNeXt50S LearnedAggregation Sandwich',
      model=partial(xeca_resnext50s, custom_head=LearnedAggregationHead),
      lr=2e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.935766,1.771872,0.458089,0.410543,00:41
1,1.78841,1.647751,0.496306,0.44971,00:36
2,1.638774,1.264864,0.68051,0.648243,00:36
3,1.553275,1.295342,0.686879,0.654775,00:36
4,1.497666,1.129743,0.746242,0.719835,00:36
5,1.449815,1.34029,0.644331,0.622218,00:36
6,1.378704,1.044096,0.792866,0.771127,00:36
7,1.342261,1.155363,0.71949,0.696003,00:36
8,1.294699,1.038282,0.775796,0.754543,00:37
9,1.270609,1.104202,0.750573,0.729099,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.944162,1.761731,0.462675,0.410153,00:36
1,1.782831,1.582641,0.54828,0.506267,00:36
2,1.613864,1.251722,0.701911,0.671117,00:36
3,1.522429,1.21347,0.698599,0.668614,00:36
4,1.437742,1.14895,0.738599,0.711853,00:36
5,1.426211,1.161238,0.73121,0.705035,00:36
6,1.385213,1.075443,0.75949,0.737523,00:36
7,1.328368,1.052744,0.778089,0.755764,00:37
8,1.348121,0.961096,0.818599,0.79927,00:37
9,1.281214,0.963532,0.818599,0.79972,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.961038,1.735453,0.489936,0.437371,00:36
1,1.765373,1.617671,0.482293,0.452903,00:36
2,1.636871,1.276058,0.679236,0.644477,00:36
3,1.565638,1.334234,0.650955,0.619969,00:36
4,1.498705,1.141668,0.741656,0.714786,00:36
5,1.43726,1.246488,0.682293,0.652453,00:36
6,1.398155,1.06107,0.777834,0.753982,00:36
7,1.333715,1.096422,0.764841,0.742615,00:36
8,1.320538,0.952946,0.817834,0.79859,00:36
9,1.313921,0.979203,0.804841,0.784545,00:36


# Learned Aggregation Sandwich Act

In [None]:
class AttentionPool2d(nn.Module):
    """Single-head attention pooling used by Learned Aggregation.

    Normalizes the input, flattens it into a (B, N, C) sequence, and lets the
    learned query `cls_q` attend over all N positions. Returns `proj` of the
    attended vector, shape (B, C).

    Args:
        ni: channel count C of the input (also the attention and output width).
        bias: whether the query and key/value projections have a bias.
        norm: factory called as `norm(ni)` for the input normalization.
            BatchNorm2d is applied to the 4D input, BatchNorm1d to the (B, C, H*W)
            flattening, and any other norm to the channel dim of the (B, H*W, C) sequence.
    """
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm2d rejects 3D input, so it must run before flattening.
        if isinstance(self.norm, nn.BatchNorm2d):
            self.norm_forward = self.bn2d_norm_flat
        elif isinstance(self.norm, nn.BatchNorm1d):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn2d_norm_flat(self, x:Tensor):
        "BatchNorm2d on the (B, C, H, W) input, then flatten to (B, H*W, C)"
        return self.norm(x).flatten(2).transpose(1,2)

    def bn_norm_flat(self, x:Tensor):
        "Flatten to (B, C, H*W), BatchNorm1d over channels, then transpose to (B, H*W, C)"
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten and transpose to (B, H*W, C), then normalize the last (channel) dim"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        "Pool feature map `x` with query `cls_q` (expected to broadcast to (B, 1, C), e.g. (1, C))"
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))              # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).unbind(2)  # each (B, N, C)

        # NOTE(review): logits are not scaled by C**-0.5 as in standard scaled
        # dot-product attention; left unscaled to preserve this experiment's behavior.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)  # (B, 1, N)

        return self.proj((attn @ v).reshape(B, C))
 
    
class LearnedAggregation(nn.Module):
    "Learned Aggregation from https://arxiv.org/abs/2112.13692"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        self.gamma_1 = nn.Parameter(1e-4 * torch.ones(ni))
        self.gamma_2 = nn.Parameter(1e-4 * torch.ones(ni))
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.norm1 = norm(ni)
        self.norm2 = norm(ni)
        self.ffn = nn.Sequential(
            nn.Linear(ni, int(ni*ffn_expand)),
            act_cls(),
            norm(int(ni*ffn_expand)),
            nn.Linear(int(ni*ffn_expand), ni)
        )
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        x = self.cls_q + self.gamma_1 * self.norm1(self.attn(x, self.cls_q))
        return x + self.gamma_2 * self.ffn(self.norm2(x))

    @torch.no_grad()
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

In [None]:
def LearnedAggregationHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, act_cls=nn.GELU, **kwargs):
    """Classifier head: `LearnedAggregation`, `norm(ni)`, `act_cls()`, then `nn.Linear(ni, n_out)`.

    `act_cls` is used both inside the aggregation block's FFN and between the final
    norm and the classifier. Every module after the pooling block goes through the
    pooling block's `_init_weights`, so the classifier `nn.Linear` is initialised
    wherever it sits in the list. Returns the four modules as a list.
    """
    head = [LearnedAggregation(ni, norm=norm, ffn_expand=ffn_expand, act_cls=act_cls, **kwargs),
            norm(ni), act_cls(), nn.Linear(ni, n_out)]
    with torch.no_grad():
        # Iterate instead of hard-coding indices: with act_cls() inserted, index 2 is the
        # activation and the classifier moved to index 3.
        for layer in head[1:]:
            head[0]._init_weights(layer)
    return head

In [None]:
train(run_name='XResNeXt50S LearnedAggregation Sandwich Act',
      model=partial(xeca_resnext50s, custom_head=LearnedAggregationHead),
      lr=2e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.229613,2.204889,0.17409,0.084549,00:41
1,2.246479,2.208562,0.178417,0.092067,00:35
2,2.250402,2.249667,0.163655,0.0722,00:35
3,2.261558,2.226099,0.181471,0.095533,00:36
4,2.25456,2.203577,0.193688,0.104323,00:35
5,2.253506,2.247835,0.155256,0.064909,00:35
6,2.259998,2.23839,0.166455,0.076953,00:35
7,2.27332,2.255743,0.14762,0.062141,00:35
8,2.259762,2.234837,0.165946,0.072855,00:35
9,2.262636,2.23427,0.173072,0.086978,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.215165,2.161765,0.208195,0.128661,00:35
1,2.23332,2.187429,0.201578,0.117724,00:35
2,2.246288,2.219075,0.200305,0.114562,00:35
3,2.241808,2.200167,0.185543,0.092498,00:35
4,2.246395,2.201646,0.186307,0.101782,00:35
5,2.247957,2.209857,0.186561,0.099982,00:35
6,2.243659,2.201751,0.190125,0.100759,00:35
7,2.242734,2.203024,0.185034,0.09386,00:35
8,2.24637,2.203598,0.177144,0.090158,00:35
9,2.241934,2.241209,0.165182,0.070367,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.255433,2.20486,0.200305,0.113573,00:35
1,2.273941,2.253045,0.173581,0.081152,00:35
2,2.250644,2.226319,0.156019,0.067027,00:35
3,2.239831,2.201669,0.186561,0.096896,00:35
4,2.245977,2.191149,0.189616,0.106662,00:35
5,2.230433,2.202088,0.18478,0.093041,00:35
6,2.25346,2.234155,0.153474,0.058431,00:36
7,2.242507,2.197605,0.188598,0.100141,00:36
8,2.245353,2.196074,0.188343,0.099576,00:35
9,2.233477,2.210443,0.186561,0.098634,00:35


# AvgAttnPooling2d Sandwich

In [None]:
class AttentionPool2d(nn.Module):
    """Single-head attention pooling used by Learned Aggregation.

    Normalizes the input, flattens it into a (B, N, C) sequence, and lets the
    learned query `cls_q` attend over all N positions. Returns `proj` of the
    attended vector, shape (B, C).

    Args:
        ni: channel count C of the input (also the attention and output width).
        bias: whether the query and key/value projections have a bias.
        norm: factory called as `norm(ni)` for the input normalization.
            BatchNorm2d is applied to the 4D input, BatchNorm1d to the (B, C, H*W)
            flattening, and any other norm to the channel dim of the (B, H*W, C) sequence.
    """
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm2d rejects 3D input, so it must run before flattening.
        if isinstance(self.norm, nn.BatchNorm2d):
            self.norm_forward = self.bn2d_norm_flat
        elif isinstance(self.norm, nn.BatchNorm1d):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn2d_norm_flat(self, x:Tensor):
        "BatchNorm2d on the (B, C, H, W) input, then flatten to (B, H*W, C)"
        return self.norm(x).flatten(2).transpose(1,2)

    def bn_norm_flat(self, x:Tensor):
        "Flatten to (B, C, H*W), BatchNorm1d over channels, then transpose to (B, H*W, C)"
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten and transpose to (B, H*W, C), then normalize the last (channel) dim"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        "Pool feature map `x` with query `cls_q` (expected to broadcast to (B, 1, C), e.g. (1, C))"
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))              # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).unbind(2)  # each (B, N, C)

        # NOTE(review): logits are not scaled by C**-0.5 as in standard scaled
        # dot-product attention; left unscaled to preserve this experiment's behavior.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)  # (B, 1, N)

        return self.proj((attn @ v).reshape(B, C))
 
    
class AvgAttnPooling2d(nn.Module):
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm = norm(ni)
        self.ffn = nn.Sequential(
            nn.Linear(ni, int(ni*ffn_expand)),
            act_cls(),
            norm(int(ni*ffn_expand)),
            nn.Linear(int(ni*ffn_expand), ni)
        )
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        x = self.norm(self.pool(x).flatten(1) + self.attn(x, self.cls_q))
        return x + self.ffn(x)

    @torch.no_grad()
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

In [None]:
def AvgAttnPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    """Classifier head: `AvgAttnPooling2d`, `norm(ni)`, then `nn.Linear(ni, n_out)`.

    Returns the three modules as a list. The norm and linear layers are passed
    through the pooling block's `_init_weights`.
    """
    pool = AvgAttnPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    final_norm = norm(ni)
    classifier = nn.Linear(ni, n_out)
    with torch.no_grad():
        for layer in (final_norm, classifier):
            pool._init_weights(layer)
    return [pool, final_norm, classifier]

In [None]:
train(run_name='XResNeXt50S AvgAttnPool Sandwich',
      model=partial(xeca_resnext50s, custom_head=AvgAttnPoolHead),
      lr=1e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.223667,2.187624,0.22703,0.159284,00:41
1,2.13506,2.104096,0.306948,0.24402,00:36
2,2.051466,1.902437,0.339272,0.285779,00:35
3,1.954682,1.842627,0.418427,0.356344,00:36
4,1.940699,1.681597,0.494528,0.442941,00:36
5,1.853244,1.69809,0.447442,0.392633,00:35
6,1.811858,1.584385,0.537032,0.488373,00:35
7,1.744685,1.600598,0.527615,0.479266,00:36
8,1.73286,1.409729,0.610588,0.567717,00:35
9,1.639355,1.466274,0.587936,0.546283,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.201554,2.034327,0.286587,0.217652,00:35
1,2.10803,1.86976,0.384322,0.318101,00:35
2,2.016401,1.868092,0.365742,0.30688,00:35
3,1.937798,1.83253,0.388649,0.329356,00:36
4,1.850792,1.638455,0.482566,0.431498,00:36
5,1.804706,1.520677,0.555612,0.506636,00:35
6,1.746489,1.511188,0.550522,0.501666,00:36
7,1.719568,1.473834,0.585136,0.542697,00:35
8,1.655676,1.405414,0.625859,0.584971,00:36
9,1.648287,1.417066,0.605498,0.565291,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.232194,2.159578,0.229829,0.161656,00:36
1,2.143567,2.082816,0.296768,0.229582,00:35
2,2.013696,1.817149,0.412319,0.3513,00:36
3,1.969599,1.906336,0.407228,0.351207,00:35
4,1.913609,1.689427,0.495291,0.442811,00:36
5,1.851156,1.634236,0.504963,0.45299,00:36
6,1.804379,1.546008,0.542377,0.49373,00:36
7,1.741615,1.634086,0.5014,0.454915,00:36
8,1.719329,1.402781,0.592517,0.551784,00:36
9,1.642217,1.40923,0.608552,0.566302,00:36


# AvgAttnPooling2d Sandwich Gamma

In [None]:
class AttentionPool2d(nn.Module):
    """Single-head attention pooling used by Learned Aggregation.

    Normalizes the input, flattens it into a (B, N, C) sequence, and lets the
    learned query `cls_q` attend over all N positions. Returns `proj` of the
    attended vector, shape (B, C).

    Args:
        ni: channel count C of the input (also the attention and output width).
        bias: whether the query and key/value projections have a bias.
        norm: factory called as `norm(ni)` for the input normalization.
            BatchNorm2d is applied to the 4D input, BatchNorm1d to the (B, C, H*W)
            flattening, and any other norm to the channel dim of the (B, H*W, C) sequence.
    """
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm2d rejects 3D input, so it must run before flattening.
        if isinstance(self.norm, nn.BatchNorm2d):
            self.norm_forward = self.bn2d_norm_flat
        elif isinstance(self.norm, nn.BatchNorm1d):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn2d_norm_flat(self, x:Tensor):
        "BatchNorm2d on the (B, C, H, W) input, then flatten to (B, H*W, C)"
        return self.norm(x).flatten(2).transpose(1,2)

    def bn_norm_flat(self, x:Tensor):
        "Flatten to (B, C, H*W), BatchNorm1d over channels, then transpose to (B, H*W, C)"
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten and transpose to (B, H*W, C), then normalize the last (channel) dim"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        "Pool feature map `x` with query `cls_q` (expected to broadcast to (B, 1, C), e.g. (1, C))"
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))              # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).unbind(2)  # each (B, N, C)

        # NOTE(review): logits are not scaled by C**-0.5 as in standard scaled
        # dot-product attention; left unscaled to preserve this experiment's behavior.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)  # (B, 1, N)

        return self.proj((attn @ v).reshape(B, C))
 
    
class AvgAttnPooling2d(nn.Module):
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.gamma_1 = nn.Parameter(1e-3 * torch.ones(ni))
        self.gamma_2 = nn.Parameter(1e-3 * torch.ones(ni))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm = norm(ni)
        self.ffn = nn.Sequential(
            nn.Linear(ni, int(ni*ffn_expand)),
            act_cls(),
            norm(int(ni*ffn_expand)),
            nn.Linear(int(ni*ffn_expand), ni)
        )
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        x = self.gamma_1*self.norm(self.pool(x).flatten(1) + self.attn(x, self.cls_q))
        return x + self.gamma_2*self.ffn(x)

    @torch.no_grad()
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

In [None]:
def AvgAttnPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    """Classifier head: `AvgAttnPooling2d`, `norm(ni)`, then `nn.Linear(ni, n_out)`.

    Returns the three modules as a list. The norm and linear layers are passed
    through the pooling block's `_init_weights`.
    """
    pool = AvgAttnPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    final_norm = norm(ni)
    classifier = nn.Linear(ni, n_out)
    with torch.no_grad():
        for layer in (final_norm, classifier):
            pool._init_weights(layer)
    return [pool, final_norm, classifier]

In [None]:
train(run_name='XResNeXt50S AvgAttnPool Sandwich Gamma',
      model=partial(xeca_resnext50s, custom_head=AvgAttnPoolHead),
      lr=3e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.992674,1.639891,0.530701,0.482395,00:42
1,1.766993,1.839444,0.385987,0.345476,00:36
2,1.65237,1.384601,0.619108,0.583147,00:37
3,1.57284,1.282285,0.682803,0.649687,00:37
4,1.5486,1.202357,0.706752,0.678939,00:36
5,1.472524,1.206962,0.705732,0.678722,00:37
6,1.410477,1.104214,0.762803,0.738972,00:37
7,1.372906,1.15141,0.732484,0.707998,00:37
8,1.321577,1.140536,0.72586,0.704802,00:37
9,1.298468,1.090199,0.75414,0.732391,00:37


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.955336,1.561201,0.567898,0.524088,00:36
1,1.738735,1.553803,0.552866,0.517241,00:37
2,1.623196,1.34015,0.652484,0.618645,00:37
3,1.545773,1.294398,0.663694,0.630293,00:37
4,1.455647,1.096946,0.761019,0.735773,00:37
5,1.443111,1.270516,0.682038,0.654321,00:36
6,1.39337,1.062295,0.77172,0.748748,00:36
7,1.338075,1.039122,0.793885,0.771854,00:37
8,1.363148,0.970578,0.815541,0.796672,00:36
9,1.294229,1.017477,0.808662,0.789017,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.915832,1.535021,0.58293,0.539474,00:36
1,1.701989,1.547704,0.545987,0.506756,00:36
2,1.621258,1.284162,0.671592,0.638186,00:37
3,1.559403,1.361559,0.635924,0.602454,00:37
4,1.507076,1.147235,0.734013,0.707003,00:36
5,1.440614,1.168669,0.73121,0.704571,00:37
6,1.397027,1.085545,0.755414,0.732357,00:36
7,1.338144,1.092655,0.769427,0.747164,00:36
8,1.324933,0.957648,0.823694,0.804674,00:37
9,1.321659,1.017686,0.796943,0.776234,00:37


# AvgAttnPooling2d Sandwich LA Gamma

In [None]:
class AttentionPool2d(nn.Module):
    """Single-head attention pooling used by Learned Aggregation.

    Normalizes the input, flattens it into a (B, N, C) sequence, and lets the
    learned query `cls_q` attend over all N positions. Returns `proj` of the
    attended vector, shape (B, C).

    Args:
        ni: channel count C of the input (also the attention and output width).
        bias: whether the query and key/value projections have a bias.
        norm: factory called as `norm(ni)` for the input normalization.
            BatchNorm2d is applied to the 4D input, BatchNorm1d to the (B, C, H*W)
            flattening, and any other norm to the channel dim of the (B, H*W, C) sequence.
    """
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm2d rejects 3D input, so it must run before flattening.
        if isinstance(self.norm, nn.BatchNorm2d):
            self.norm_forward = self.bn2d_norm_flat
        elif isinstance(self.norm, nn.BatchNorm1d):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn2d_norm_flat(self, x:Tensor):
        "BatchNorm2d on the (B, C, H, W) input, then flatten to (B, H*W, C)"
        return self.norm(x).flatten(2).transpose(1,2)

    def bn_norm_flat(self, x:Tensor):
        "Flatten to (B, C, H*W), BatchNorm1d over channels, then transpose to (B, H*W, C)"
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten and transpose to (B, H*W, C), then normalize the last (channel) dim"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        "Pool feature map `x` with query `cls_q` (expected to broadcast to (B, 1, C), e.g. (1, C))"
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))              # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).unbind(2)  # each (B, N, C)

        # NOTE(review): logits are not scaled by C**-0.5 as in standard scaled
        # dot-product attention; left unscaled to preserve this experiment's behavior.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)  # (B, 1, N)

        return self.proj((attn @ v).reshape(B, C))
 
    
class AvgAttnPooling2d(nn.Module):
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.gamma_1 = nn.Parameter(1e-3 * torch.ones(ni))
        self.gamma_2 = nn.Parameter(1e-3 * torch.ones(ni))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm = norm(ni)
        self.ffn = nn.Sequential(
            nn.Linear(ni, int(ni*ffn_expand)),
            act_cls(),
            norm(int(ni*ffn_expand)),
            nn.Linear(int(ni*ffn_expand), ni)
        )
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        x = self.gamma_1*self.norm(self.pool(x).flatten(1) + self.attn(x, self.cls_q))
        return x + self.gamma_2*self.ffn(x)

    @torch.no_grad()
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

In [None]:
def AvgAttnPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Head layers: AvgAttnPooling2d -> norm -> Linear classifier (trunc-normal init on Linear)"
    pool = AvgAttnPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    head = [pool, norm(ni), nn.Linear(ni, n_out)]
    with torch.no_grad():
        for layer in head[1:]:
            pool._init_weights(layer)  # only nn.Linear layers are re-initialized
    return head

In [None]:
# Uses whichever AvgAttnPoolHead / AvgAttnPooling2d were most recently defined in the notebook
train(run_name='XResNeXt50S AvgAttnPool Sandwich LAGamma',
      model=partial(xeca_resnext50s, custom_head=AvgAttnPoolHead),
      lr=3e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.976451,1.640859,0.517452,0.469733,00:36
1,1.753474,1.647473,0.487389,0.445263,00:36
2,1.647787,1.309432,0.660892,0.626744,00:36
3,1.560541,1.312274,0.648153,0.616792,00:37
4,1.515805,1.172659,0.724586,0.698031,00:36
5,1.471615,1.289837,0.66344,0.633806,00:36
6,1.402861,1.095267,0.772229,0.748869,00:36
7,1.356515,1.144584,0.736051,0.713652,00:36
8,1.31354,1.141087,0.719745,0.698321,00:37
9,1.284197,1.086503,0.756433,0.736565,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.968404,1.605534,0.544204,0.498612,00:36
1,1.74209,1.627252,0.523057,0.482294,00:36
2,1.618162,1.298611,0.672866,0.639012,00:36
3,1.547022,1.325226,0.647643,0.613006,00:36
4,1.4652,1.120047,0.750828,0.725498,00:36
5,1.441691,1.344132,0.649936,0.623071,00:36
6,1.400128,1.073664,0.777325,0.754396,00:36
7,1.343866,1.043854,0.787006,0.76523,00:37
8,1.35588,0.959078,0.821911,0.803343,00:37
9,1.30287,1.002414,0.81121,0.791551,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.914555,1.560852,0.55949,0.513795,00:36
1,1.700104,1.53527,0.541656,0.502348,00:36
2,1.615138,1.38571,0.608153,0.573892,00:37
3,1.55908,1.297184,0.672102,0.639324,00:37
4,1.515006,1.15985,0.739873,0.71472,00:36
5,1.438465,1.143023,0.753885,0.728247,00:36
6,1.392676,1.064309,0.772484,0.749675,00:36
7,1.331549,1.185868,0.716178,0.693835,00:36
8,1.330018,0.965958,0.812229,0.793384,00:36
9,1.326679,1.035955,0.783949,0.762572,00:36


# AvgAttnPooling2d Sandwich Act

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation: pool a feature map into one vector with a learned query"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm normalizes channel dim 1, so it runs before the token transpose;
        # other norms (e.g. LayerNorm) normalize the last dim, so they run after it.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "BatchNorm `x` (B, C, H, W) and return tokens of shape (B, H*W, C)"
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d rejects 3D input, so normalize the 4D map before flattening.
            # Per-channel stats over (B, H, W) equal BatchNorm1d's stats over (B, H*W).
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` (B, C, H, W) to tokens (B, H*W, C), then normalize over channels"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool `x` (B, C, H, W) into (B, C) by attending from the learned query `cls_q`.

        `cls_q` is expected to be (1, C) and is broadcast over the batch.
        NOTE(review): logits are not scaled by 1/sqrt(C) as in standard scaled
        dot-product attention; left unchanged to preserve the experiments' behavior.
        """
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))   # (B, 1, C)
        k, v = self.vk(x).chunk(2, dim=-1)     # each (B, N, C): keys first, values second

        attn = q @ k.transpose(-2, -1)         # (B, 1, N)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnPooling2d(nn.Module):
    """Sum average pooling and learned-query attention pooling, normalize, add a residual FFN.

    forward: y = norm(avgpool(x) + attn(x)); returns y + ffn(y), i.e. (B, ni) for x of (B, ni, H, W).
    """
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        hidden = int(ni*ffn_expand)
        # Registration order fixes state_dict keys and the RNG order of default inits
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm = norm(ni)
        ffn_layers = [nn.Linear(ni, hidden), act_cls(), norm(hidden), nn.Linear(hidden, ni)]
        self.ffn = nn.Sequential(*ffn_layers)
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        pooled = self.pool(x).flatten(1)
        attended = self.attn(x, self.cls_q)
        fused = self.norm(pooled + attended)
        return fused + self.ffn(fused)

    @torch.no_grad()
    def _init_weights(self, m):
        "Trunc-normal (std 0.02) weights and zero bias for `nn.Linear`; other modules untouched"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
def AvgAttnPoolHead(ni, n_out, norm=nn.LayerNorm, act_cls=nn.GELU, ffn_expand=3, **kwargs):
    "Head layers: AvgAttnPooling2d -> norm -> act -> Linear classifier (trunc-normal init on Linear)"
    pool = AvgAttnPooling2d(ni, norm=norm, ffn_expand=ffn_expand, act_cls=act_cls, **kwargs)
    head = [pool, norm(ni), act_cls(), nn.Linear(ni, n_out)]
    with torch.no_grad():
        # Initialize every layer after the pooling module. Fixed indices (head[1], head[2])
        # hit only the norm and the activation here, leaving the classifier at default init.
        for layer in head[1:]:
            pool._init_weights(layer)
    return head

In [None]:
# Uses whichever AvgAttnPoolHead / AvgAttnPooling2d were most recently defined in the notebook
train(run_name='XResNeXt50S AvgAttnPool Sandwich Act',
      model=partial(xeca_resnext50s, custom_head=AvgAttnPoolHead),
      lr=3e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.993758,1.768419,0.408153,0.360851,00:37
1,1.780536,1.470875,0.602803,0.562844,00:36
2,1.688297,1.554974,0.556433,0.517529,00:36
3,1.59102,1.257413,0.681019,0.648142,00:37
4,1.567183,1.196406,0.721019,0.691918,00:36
5,1.520048,1.337408,0.653758,0.622428,00:36
6,1.440801,1.194196,0.716943,0.690687,00:37
7,1.380679,1.221442,0.715669,0.687785,00:36
8,1.348964,1.106338,0.746497,0.723193,00:38
9,1.304105,1.360989,0.62828,0.608918,00:37


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.965777,1.766941,0.422166,0.371248,00:37
1,1.787641,1.446763,0.617325,0.578101,00:37
2,1.677008,1.462388,0.580382,0.545867,00:37
3,1.590339,1.249559,0.687643,0.655194,00:37
4,1.517742,1.141485,0.742675,0.715691,00:37
5,1.488596,1.246402,0.705732,0.678556,00:36
6,1.447501,1.225977,0.695541,0.66811,00:36
7,1.3858,1.044294,0.772739,0.748452,00:37
8,1.399643,1.164508,0.723567,0.700778,00:36
9,1.352807,1.018094,0.793885,0.772729,00:37


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.925156,1.643934,0.488153,0.442966,00:36
1,1.737157,1.406209,0.624459,0.585216,00:36
2,1.639721,1.462125,0.602803,0.561711,00:37
3,1.606473,1.208929,0.727898,0.699057,00:37
4,1.533789,1.234806,0.699108,0.670167,00:36
5,1.481778,1.112186,0.750318,0.724076,00:37
6,1.454502,1.164672,0.724076,0.698544,00:36
7,1.371542,0.991654,0.808408,0.788118,00:37
8,1.351423,1.030589,0.782166,0.759946,00:37
9,1.360276,0.922578,0.838726,0.820883,00:37


# AvgAttnPooling2d SansFFN

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation: pool a feature map into one vector with a learned query"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm normalizes channel dim 1, so it runs before the token transpose;
        # other norms (e.g. LayerNorm) normalize the last dim, so they run after it.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "BatchNorm `x` (B, C, H, W) and return tokens of shape (B, H*W, C)"
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d rejects 3D input, so normalize the 4D map before flattening.
            # Per-channel stats over (B, H, W) equal BatchNorm1d's stats over (B, H*W).
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` (B, C, H, W) to tokens (B, H*W, C), then normalize over channels"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool `x` (B, C, H, W) into (B, C) by attending from the learned query `cls_q`.

        `cls_q` is expected to be (1, C) and is broadcast over the batch.
        NOTE(review): logits are not scaled by 1/sqrt(C) as in standard scaled
        dot-product attention; left unchanged to preserve the experiments' behavior.
        """
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))   # (B, 1, C)
        k, v = self.vk(x).chunk(2, dim=-1)     # each (B, N, C): keys first, values second

        attn = q @ k.transpose(-2, -1)         # (B, 1, N)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnPooling2d(nn.Module):
    """Normalized sum of average pooling and learned-query attention pooling (no FFN).

    `ffn_expand` and `act_cls` are accepted but unused, keeping the signature
    interchangeable with the FFN variants.
    """
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        # Registration order fixes state_dict keys and the RNG order of default inits
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm = norm(ni)
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        pooled = self.pool(x).flatten(1)
        attended = self.attn(x, self.cls_q)
        return self.norm(pooled + attended)

    @torch.no_grad()
    def _init_weights(self, m):
        "Trunc-normal (std 0.02) weights and zero bias for `nn.Linear`; other modules untouched"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
def AvgAttnPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Head layers: AvgAttnPooling2d followed by a trunc-normal-initialized Linear classifier"
    pool = AvgAttnPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    classifier = nn.Linear(ni, n_out)
    with torch.no_grad():
        pool._init_weights(classifier)
    return [pool, classifier]

In [None]:
# Uses whichever AvgAttnPoolHead / AvgAttnPooling2d were most recently defined in the notebook
train(run_name='XResNeXt50S AvgAttnPool SansFFN',
      model=partial(xeca_resnext50s, custom_head=AvgAttnPoolHead),
      lr=3e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.993137,1.739832,0.46293,0.414448,00:36
1,1.800211,1.627798,0.499108,0.457815,00:36
2,1.696901,1.474545,0.570701,0.533047,00:36
3,1.590903,1.476235,0.589554,0.554821,00:36
4,1.560996,1.224947,0.700892,0.671637,00:36
5,1.522658,1.285386,0.674395,0.645825,00:36
6,1.439399,1.091211,0.769682,0.747101,00:36
7,1.385979,1.190017,0.715159,0.689798,00:36
8,1.339075,1.103412,0.744459,0.721941,00:37
9,1.313615,1.06294,0.773758,0.752879,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.052639,1.697377,0.492994,0.443693,00:36
1,1.798494,1.655149,0.487643,0.444709,00:36
2,1.677849,1.349411,0.655796,0.620522,00:36
3,1.601465,1.428925,0.605605,0.571187,00:36
4,1.506057,1.220415,0.698089,0.669613,00:36
5,1.491138,1.313922,0.666497,0.641735,00:36
6,1.4298,1.068625,0.778089,0.754404,00:36
7,1.380516,1.15379,0.729427,0.703897,00:36
8,1.400736,0.978433,0.812229,0.792724,00:36
9,1.329069,1.04018,0.784713,0.763426,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.982553,1.719822,0.472866,0.425997,00:36
1,1.75298,1.598225,0.52,0.48213,00:36
2,1.657902,1.422863,0.610191,0.573986,00:36
3,1.604762,1.373176,0.647389,0.610598,00:36
4,1.527552,1.259165,0.69758,0.665941,00:36
5,1.483963,1.265721,0.684076,0.65499,00:36
6,1.439722,1.109163,0.757707,0.733464,00:36
7,1.38429,1.177143,0.752611,0.726772,00:37
8,1.346938,0.962388,0.820637,0.801728,00:36
9,1.344397,0.98974,0.803312,0.783029,00:36


# AvgAttnPooling2d SansFFN Act

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation: pool a feature map into one vector with a learned query"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm normalizes channel dim 1, so it runs before the token transpose;
        # other norms (e.g. LayerNorm) normalize the last dim, so they run after it.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "BatchNorm `x` (B, C, H, W) and return tokens of shape (B, H*W, C)"
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d rejects 3D input, so normalize the 4D map before flattening.
            # Per-channel stats over (B, H, W) equal BatchNorm1d's stats over (B, H*W).
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` (B, C, H, W) to tokens (B, H*W, C), then normalize over channels"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool `x` (B, C, H, W) into (B, C) by attending from the learned query `cls_q`.

        `cls_q` is expected to be (1, C) and is broadcast over the batch.
        NOTE(review): logits are not scaled by 1/sqrt(C) as in standard scaled
        dot-product attention; left unchanged to preserve the experiments' behavior.
        """
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))   # (B, 1, C)
        k, v = self.vk(x).chunk(2, dim=-1)     # each (B, N, C): keys first, values second

        attn = q @ k.transpose(-2, -1)         # (B, 1, N)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnPooling2d(nn.Module):
    """Activated, normalized sum of average pooling and learned-query attention pooling (no FFN).

    forward: act(norm(avgpool(x) + attn(x))). `ffn_expand` is accepted but unused,
    keeping the signature interchangeable with the FFN variants.
    """
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        # Registration order fixes state_dict keys and the RNG order of default inits
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm = norm(ni)
        self.act = act_cls()
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        pooled = self.pool(x).flatten(1)
        attended = self.attn(x, self.cls_q)
        normed = self.norm(pooled + attended)
        return self.act(normed)

    @torch.no_grad()
    def _init_weights(self, m):
        "Trunc-normal (std 0.02) weights and zero bias for `nn.Linear`; other modules untouched"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
def AvgAttnPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Head layers: AvgAttnPooling2d followed by a trunc-normal-initialized Linear classifier"
    pool = AvgAttnPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    classifier = nn.Linear(ni, n_out)
    with torch.no_grad():
        pool._init_weights(classifier)
    return [pool, classifier]

In [None]:
# Uses whichever AvgAttnPoolHead / AvgAttnPooling2d were most recently defined in the notebook
train(run_name='XResNeXt50S AvgAttnPool SansFFN Act',
      model=partial(xeca_resnext50s, custom_head=AvgAttnPoolHead),
      lr=4e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.943459,1.710741,0.477962,0.430979,00:36
1,1.767894,1.408009,0.629299,0.589822,00:36
2,1.666107,1.371624,0.630318,0.594754,00:36
3,1.621284,1.493386,0.55949,0.522307,00:36
4,1.579743,1.381727,0.608662,0.576331,00:36
5,1.519817,1.328818,0.659363,0.633444,00:36
6,1.429425,1.197822,0.725605,0.697906,00:36
7,1.374951,1.178558,0.712357,0.688529,00:36
8,1.335008,1.062739,0.767389,0.744348,00:37
9,1.303676,1.025105,0.786752,0.766287,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.971528,1.922931,0.35949,0.312536,00:36
1,1.773244,1.427597,0.62242,0.583868,00:36
2,1.649861,1.522328,0.553121,0.518632,00:36
3,1.592307,1.274694,0.675414,0.642877,00:36
4,1.508863,1.192602,0.706752,0.67942,00:37
5,1.497391,1.280398,0.671338,0.64261,00:36
6,1.424233,1.141105,0.740637,0.715364,00:36
7,1.378681,1.178343,0.716688,0.690015,00:36
8,1.366217,1.071814,0.761019,0.738461,00:36
9,1.329481,1.042776,0.793631,0.7716,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.969082,1.596342,0.513121,0.464825,00:36
1,1.739893,1.425653,0.619873,0.580655,00:36
2,1.639152,1.510048,0.572994,0.538955,00:36
3,1.600888,1.246637,0.690191,0.660315,00:36
4,1.540677,1.197583,0.698344,0.667892,00:36
5,1.464658,1.06366,0.780892,0.757255,00:36
6,1.440815,1.072677,0.767389,0.74384,00:36
7,1.359781,1.137641,0.750828,0.725483,00:36
8,1.35776,1.081392,0.755924,0.734366,00:36
9,1.340936,1.054572,0.764331,0.741895,00:36


# AvgAttnConcatPooling2d SansFFN

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation: pool a feature map into one vector with a learned query"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm normalizes channel dim 1, so it runs before the token transpose;
        # other norms (e.g. LayerNorm) normalize the last dim, so they run after it.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "BatchNorm `x` (B, C, H, W) and return tokens of shape (B, H*W, C)"
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d rejects 3D input, so normalize the 4D map before flattening.
            # Per-channel stats over (B, H, W) equal BatchNorm1d's stats over (B, H*W).
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` (B, C, H, W) to tokens (B, H*W, C), then normalize over channels"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool `x` (B, C, H, W) into (B, C) by attending from the learned query `cls_q`.

        `cls_q` is expected to be (1, C) and is broadcast over the batch.
        NOTE(review): logits are not scaled by 1/sqrt(C) as in standard scaled
        dot-product attention; left unchanged to preserve the experiments' behavior.
        """
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))   # (B, 1, C)
        k, v = self.vk(x).chunk(2, dim=-1)     # each (B, N, C): keys first, values second

        attn = q @ k.transpose(-2, -1)         # (B, 1, N)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnConcatPooling2d(nn.Module):
    """Concatenate separately normalized average-pooled and attention-pooled features (no FFN).

    forward returns [norm2(avgpool(x)), norm1(attn(x))] along dim 1, i.e. (B, 2*ni).
    `ffn_expand` and `act_cls` are accepted but unused (signature parity with FFN variants).
    """
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        # Registration order fixes state_dict keys and the RNG order of default inits
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.norm1 = norm(ni)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm2 = norm(ni)
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        avg_feats = self.norm2(self.pool(x).flatten(1))
        attn_feats = self.norm1(self.attn(x, self.cls_q))
        return torch.cat([avg_feats, attn_feats], dim=1)

    @torch.no_grad()
    def _init_weights(self, m):
        "Trunc-normal (std 0.02) weights and zero bias for `nn.Linear`; other modules untouched"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
def AvgAttnConcatPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Head layers: AvgAttnConcatPooling2d (2*ni features) then a trunc-normal-initialized Linear"
    pool = AvgAttnConcatPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    classifier = nn.Linear(2*ni, n_out)  # concat of avg and attention features doubles the width
    with torch.no_grad():
        pool._init_weights(classifier)
    return [pool, classifier]

In [None]:
# Uses whichever AvgAttnConcatPoolHead / AvgAttnConcatPooling2d were most recently defined
train(run_name='XResNeXt50S AvgAttnConcatPool SansFFN',
      model=partial(xeca_resnext50s, custom_head=AvgAttnConcatPoolHead),
      lr=3e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.366971,2.260676,0.197251,0.108498,00:35
1,2.275674,2.22739,0.249427,0.169591,00:35
2,2.190456,2.033031,0.276152,0.199662,00:35
3,2.107399,2.086015,0.259099,0.181966,00:35
4,2.054559,1.810588,0.424536,0.360386,00:35
5,1.97926,1.863871,0.388394,0.331342,00:35
6,1.935258,1.730259,0.461441,0.404557,00:35
7,1.841612,1.669784,0.484093,0.429073,00:35
8,1.860053,1.556025,0.550522,0.501578,00:35
9,1.736675,1.61654,0.490456,0.44547,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.339253,2.268467,0.210741,0.134237,00:35
1,2.241684,2.297435,0.221176,0.141078,00:34
2,2.255845,2.059565,0.273352,0.193018,00:34
3,2.2135,2.163832,0.225248,0.144836,00:35
4,2.113716,1.963778,0.31942,0.246854,00:35
5,2.048318,2.013367,0.300076,0.228465,00:35
6,1.998665,1.786578,0.421736,0.358215,00:35
7,1.988235,2.019295,0.317129,0.255118,00:35
8,1.970553,1.860774,0.39323,0.327146,00:35
9,1.909728,1.673725,0.482566,0.430664,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.333884,2.383904,0.185798,0.121208,00:35
1,2.253497,2.220833,0.235683,0.161632,00:35
2,2.091516,1.896108,0.365487,0.296412,00:35
3,2.040098,2.004278,0.318147,0.256528,00:34
4,1.973223,1.74766,0.431662,0.37292,00:35
5,1.900835,1.779935,0.424026,0.374776,00:35
6,1.832726,1.719355,0.463222,0.408671,00:34
7,1.801134,1.667203,0.482566,0.430705,00:35
8,1.731838,1.391216,0.619496,0.576751,00:35
9,1.640953,1.459234,0.587172,0.544438,00:35


# AvgAttnConcatPooling2d SansFFN Act

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation: pool a feature map into one vector with a learned query"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm normalizes channel dim 1, so it runs before the token transpose;
        # other norms (e.g. LayerNorm) normalize the last dim, so they run after it.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "BatchNorm `x` (B, C, H, W) and return tokens of shape (B, H*W, C)"
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d rejects 3D input, so normalize the 4D map before flattening.
            # Per-channel stats over (B, H, W) equal BatchNorm1d's stats over (B, H*W).
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` (B, C, H, W) to tokens (B, H*W, C), then normalize over channels"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool `x` (B, C, H, W) into (B, C) by attending from the learned query `cls_q`.

        `cls_q` is expected to be (1, C) and is broadcast over the batch.
        NOTE(review): logits are not scaled by 1/sqrt(C) as in standard scaled
        dot-product attention; left unchanged to preserve the experiments' behavior.
        """
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))   # (B, 1, C)
        k, v = self.vk(x).chunk(2, dim=-1)     # each (B, N, C): keys first, values second

        attn = q @ k.transpose(-2, -1)         # (B, 1, N)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnConcatPooling2d(nn.Module):
    """Activated concat of separately normalized average-pooled and attention-pooled features.

    forward returns act([norm2(avgpool(x)), norm1(attn(x))]) along dim 1, i.e. (B, 2*ni).
    `ffn_expand` is accepted but unused (signature parity with FFN variants).
    """
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        # Registration order fixes state_dict keys and the RNG order of default inits
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.norm1 = norm(ni)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm2 = norm(ni)
        self.act = act_cls()
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        avg_feats = self.norm2(self.pool(x).flatten(1))
        attn_feats = self.norm1(self.attn(x, self.cls_q))
        combined = torch.cat([avg_feats, attn_feats], dim=1)
        return self.act(combined)

    @torch.no_grad()
    def _init_weights(self, m):
        "Trunc-normal (std 0.02) weights and zero bias for `nn.Linear`; other modules untouched"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
def AvgAttnConcatPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Head layers: AvgAttnConcatPooling2d (2*ni features) then a trunc-normal-initialized Linear"
    pool = AvgAttnConcatPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    classifier = nn.Linear(2*ni, n_out)  # concat of avg and attention features doubles the width
    with torch.no_grad():
        pool._init_weights(classifier)
    return [pool, classifier]

In [None]:
# Uses whichever AvgAttnConcatPoolHead / AvgAttnConcatPooling2d were most recently defined
train(run_name='XResNeXt50S AvgAttnConcatPool SansFFN Act',
      model=partial(xeca_resnext50s, custom_head=AvgAttnConcatPoolHead),
      lr=2e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.270928,2.197875,0.227793,0.147633,00:36
1,2.213936,2.149503,0.235429,0.154534,00:35
2,2.18476,2.132705,0.236192,0.165345,00:35
3,2.070524,1.925798,0.348689,0.279598,00:36
4,2.012618,1.844382,0.389667,0.32421,00:35
5,1.915658,1.689243,0.463477,0.404112,00:35
6,1.856117,1.932482,0.374905,0.321663,00:35
7,1.793912,1.55931,0.528633,0.477962,00:35
8,1.757324,1.615135,0.502418,0.455416,00:36
9,1.674246,1.385472,0.616442,0.574376,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.266019,2.118241,0.248918,0.167438,00:35
1,2.195801,2.051439,0.25859,0.175555,00:35
2,2.155887,2.026029,0.283787,0.207453,00:35
3,2.03515,2.083324,0.296259,0.227823,00:35
4,1.947282,1.814831,0.388394,0.32386,00:35
5,1.88697,2.026854,0.317893,0.258409,00:35
6,1.823382,1.607231,0.516925,0.465391,00:35
7,1.773249,1.695145,0.471621,0.421721,00:35
8,1.712179,1.454309,0.588954,0.5439,00:36
9,1.695325,1.544664,0.538305,0.495257,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.248329,2.152552,0.230848,0.147614,00:35
1,2.211167,2.137051,0.244592,0.162438,00:35
2,2.109488,1.967365,0.321201,0.247905,00:35
3,2.021063,1.970587,0.314075,0.252719,00:35
4,2.046623,1.969465,0.315347,0.244113,00:35
5,1.948036,1.769905,0.42479,0.364091,00:35
6,1.899768,1.704264,0.457368,0.39773,00:35
7,1.843934,1.758105,0.425299,0.371699,00:35
8,1.825862,1.545109,0.554085,0.505592,00:35
9,1.751622,1.554195,0.537032,0.488034,00:35


# AvgAttnConcatPooling2d Sandwich

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation: pool a feature map into one vector with a learned query"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm normalizes channel dim 1, so it runs before the token transpose;
        # other norms (e.g. LayerNorm) normalize the last dim, so they run after it.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "BatchNorm `x` (B, C, H, W) and return tokens of shape (B, H*W, C)"
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d rejects 3D input, so normalize the 4D map before flattening.
            # Per-channel stats over (B, H, W) equal BatchNorm1d's stats over (B, H*W).
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` (B, C, H, W) to tokens (B, H*W, C), then normalize over channels"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool `x` (B, C, H, W) into (B, C) by attending from the learned query `cls_q`.

        `cls_q` is expected to be (1, C) and is broadcast over the batch.
        NOTE(review): logits are not scaled by 1/sqrt(C) as in standard scaled
        dot-product attention; left unchanged to preserve the experiments' behavior.
        """
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))   # (B, 1, C)
        k, v = self.vk(x).chunk(2, dim=-1)     # each (B, N, C): keys first, values second

        attn = q @ k.transpose(-2, -1)         # (B, 1, N)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnConcatPooling2d(nn.Module):
    """Concat normalized average-pooled features with a sandwich-normed attention token.

    Attention branch: a = cls_q + norm1(attn(x)); a = a + ffn(norm2(a)).
    forward returns [norm4(avgpool(x)), norm3(a)] along dim 1, i.e. (B, 2*ni).
    """
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        hidden = int(ni*ffn_expand)
        # Registration order fixes state_dict keys and the RNG order of default inits
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.norm1 = norm(ni)
        self.norm2 = norm(ni)
        ffn_layers = [nn.Linear(ni, hidden), act_cls(), norm(hidden), nn.Linear(hidden, ni)]
        self.ffn = nn.Sequential(*ffn_layers)
        self.norm3 = norm(ni)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm4 = norm(ni)
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        query = self.cls_q
        token = query + self.norm1(self.attn(x, query))   # (1, ni) query broadcasts over batch
        token = token + self.ffn(self.norm2(token))
        avg_feats = self.norm4(self.pool(x).flatten(1))
        return torch.cat([avg_feats, self.norm3(token)], dim=1)

    @torch.no_grad()
    def _init_weights(self, m):
        "Trunc-normal (std 0.02) weights and zero bias for `nn.Linear`; other modules untouched"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
def AvgAttnConcatPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Head layers: AvgAttnConcatPooling2d (2*ni features) then a trunc-normal-initialized Linear"
    pool = AvgAttnConcatPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    classifier = nn.Linear(2*ni, n_out)  # concat of avg and attention features doubles the width
    with torch.no_grad():
        pool._init_weights(classifier)
    return [pool, classifier]

In [None]:
# Uses whichever AvgAttnConcatPoolHead / AvgAttnConcatPooling2d were most recently defined
train(run_name='XResNeXt50S AvgAttnConcatPool Sandwich',
      model=partial(xeca_resnext50s, custom_head=AvgAttnConcatPoolHead),
      lr=2e-3)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.951038,1.623678,0.52535,0.478966,00:36
1,1.802131,1.765291,0.484586,0.445918,00:36
2,1.675112,1.340351,0.652994,0.619248,00:36
3,1.579748,1.451055,0.593121,0.554759,00:36
4,1.625115,1.342645,0.654267,0.619089,00:36
5,1.540457,1.273389,0.673885,0.645848,00:36
6,1.569982,1.369295,0.641783,0.610077,00:36
7,1.511734,1.169572,0.718726,0.689036,00:37
8,1.439583,1.215142,0.707771,0.678567,00:37
9,1.409969,1.126261,0.73707,0.710522,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.03128,1.688897,0.484076,0.438103,00:36
1,1.891696,1.720521,0.469809,0.418508,00:36
2,1.753306,1.609591,0.528662,0.487945,00:36
3,1.703121,1.553074,0.536815,0.490251,00:36
4,1.737858,1.631133,0.534777,0.490032,00:37
5,1.678609,1.31661,0.664713,0.630311,00:36
6,1.558489,1.461343,0.596943,0.572001,00:36
7,1.523184,1.218567,0.710318,0.679368,00:37
8,1.546395,1.550316,0.54293,0.509426,00:36
9,1.511101,1.230744,0.688153,0.661286,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.038269,1.809314,0.455287,0.409405,00:36
1,1.81487,1.8524,0.440764,0.392349,00:36
2,1.773658,1.549861,0.564841,0.523447,00:37
3,1.711823,1.678691,0.492229,0.442465,00:36
4,1.746901,1.489584,0.572994,0.533448,00:36
5,1.679086,1.331856,0.66293,0.628028,00:36
6,1.599758,1.316918,0.664713,0.630226,00:36
7,1.528476,1.182904,0.723057,0.693911,00:36
8,1.490244,1.216561,0.703949,0.673425,00:36
9,1.460848,1.143267,0.735032,0.707333,00:36


# AvgAttnConcatPooling2d Sandwich Act

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation: pool a feature map into one vector with a learned query"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm normalizes channel dim 1, so it runs before the token transpose;
        # other norms (e.g. LayerNorm) normalize the last dim, so they run after it.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "BatchNorm `x` (B, C, H, W) and return tokens of shape (B, H*W, C)"
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d rejects 3D input, so normalize the 4D map before flattening.
            # Per-channel stats over (B, H, W) equal BatchNorm1d's stats over (B, H*W).
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` (B, C, H, W) to tokens (B, H*W, C), then normalize over channels"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool `x` (B, C, H, W) into (B, C) by attending from the learned query `cls_q`.

        `cls_q` is expected to be (1, C) and is broadcast over the batch.
        NOTE(review): logits are not scaled by 1/sqrt(C) as in standard scaled
        dot-product attention; left unchanged to preserve the experiments' behavior.
        """
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))   # (B, 1, C)
        k, v = self.vk(x).chunk(2, dim=-1)     # each (B, N, C): keys first, values second

        attn = q @ k.transpose(-2, -1)         # (B, 1, N)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnConcatPooling2d(nn.Module):
    """Activated concat of normalized average-pooled features and a sandwich-normed attention token.

    Attention branch: a = cls_q + norm1(attn(x)); a = a + ffn(norm2(a)).
    forward returns act([norm4(avgpool(x)), norm3(a)]) along dim 1, i.e. (B, 2*ni).
    """
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        hidden = int(ni*ffn_expand)
        # Registration order fixes state_dict keys and the RNG order of default inits
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.norm1 = norm(ni)
        self.norm2 = norm(ni)
        ffn_layers = [nn.Linear(ni, hidden), act_cls(), norm(hidden), nn.Linear(hidden, ni)]
        self.ffn = nn.Sequential(*ffn_layers)
        self.norm3 = norm(ni)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm4 = norm(ni)
        self.act = act_cls()
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        query = self.cls_q
        token = query + self.norm1(self.attn(x, query))   # (1, ni) query broadcasts over batch
        token = token + self.ffn(self.norm2(token))
        avg_feats = self.norm4(self.pool(x).flatten(1))
        combined = torch.cat([avg_feats, self.norm3(token)], dim=1)
        return self.act(combined)

    @torch.no_grad()
    def _init_weights(self, m):
        "Trunc-normal (std 0.02) weights and zero bias for `nn.Linear`; other modules untouched"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
def AvgAttnConcatPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    """Build a classifier head: AvgAttnConcatPooling2d followed by a Linear.

    Returns a plain list rather than an `nn.Sequential`. Presumably the
    model's `custom_head` hook unpacks it into its layer sequence (TODO
    confirm). The pooling layer emits 2*ni features (average + attention
    halves), so the classifier maps 2*ni -> n_out. The classifier is
    initialized with the pooling module's `_init_weights` (truncated normal
    std 0.02, zero bias).
    """
    pooling = AvgAttnConcatPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    classifier = nn.Linear(2*ni, n_out)
    with torch.no_grad():
        pooling._init_weights(classifier)
    return [pooling, classifier]

In [None]:
train(
    'XResNeXt50S AvgAttnConcatPool Sandwich Act',
    model=partial(xeca_resnext50s, custom_head=AvgAttnConcatPoolHead),
    lr=4e-3,
)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.915825,1.684322,0.500382,0.456674,00:36
1,1.742307,1.88856,0.441783,0.396051,00:36
2,1.731655,1.490908,0.567898,0.527589,00:36
3,1.583261,1.604922,0.530446,0.495665,00:36
4,1.55649,1.229888,0.70293,0.672327,00:36
5,1.502965,1.830608,0.458599,0.414735,00:36
6,1.436867,1.124148,0.751083,0.725341,00:36
7,1.411996,1.104097,0.749299,0.722754,00:36
8,1.387518,1.077315,0.764076,0.738495,00:37
9,1.347494,1.083517,0.764076,0.740882,00:36


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,2.063851,1.666965,0.494522,0.441212,00:38
1,1.943247,1.795101,0.438217,0.386953,00:37
2,1.823075,1.467947,0.595669,0.554758,00:38
3,1.664064,1.504277,0.558217,0.525143,00:37
4,1.560798,1.258227,0.687898,0.656272,00:37
5,1.526358,1.416314,0.61707,0.5871,00:37
6,1.450396,1.148999,0.729682,0.703887,00:37
7,1.389304,1.186203,0.717707,0.689122,00:37
8,1.383999,1.043414,0.78293,0.760481,00:37
9,1.307318,1.002378,0.810191,0.790203,00:37


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.989354,1.714458,0.481274,0.434292,00:37
1,1.727813,1.694871,0.522803,0.48157,00:36
2,1.783363,1.467192,0.576815,0.533395,00:37
3,1.664021,1.474582,0.584459,0.544089,00:37
4,1.575259,1.295711,0.663185,0.631028,00:36
5,1.509671,1.241149,0.700382,0.669671,00:37
6,1.454965,1.100834,0.76,0.733954,00:37
7,1.375467,1.198036,0.70293,0.675948,00:37
8,1.350944,1.015136,0.792357,0.76999,00:37
9,1.327655,1.113034,0.750318,0.725944,00:37


# XResNeXt50S (no custom head), lr = 4e-3

In [None]:
train(
    'XResNeXt50S',
    model=xeca_resnext50s,
    lr=4e-3,
)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.880327,2.758565,0.34293,0.312304,00:36
1,1.680574,1.364445,0.651974,0.614797,00:35
2,1.567609,1.266526,0.696051,0.664818,00:35
3,1.480771,1.168921,0.73707,0.709101,00:35
4,1.438842,1.167073,0.732739,0.708076,00:34
5,1.393064,1.080551,0.778599,0.755642,00:35
6,1.313726,1.123547,0.745732,0.722022,00:35
7,1.292526,1.003577,0.805096,0.785443,00:35
8,1.23386,1.080971,0.754904,0.733628,00:36
9,1.204403,0.866798,0.861656,0.846729,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.941069,1.62896,0.524331,0.475258,00:35
1,1.70525,1.367437,0.655287,0.619816,00:35
2,1.551901,1.335686,0.657834,0.625667,00:35
3,1.475672,1.168957,0.727643,0.698909,00:35
4,1.392535,1.145185,0.747006,0.720952,00:35
5,1.379348,1.045446,0.780892,0.758509,00:35
6,1.316699,1.321466,0.652229,0.633151,00:35
7,1.268308,0.965607,0.825987,0.807501,00:35
8,1.279191,1.098854,0.763822,0.74279,00:35
9,1.221056,0.918143,0.842548,0.825988,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.932623,2.012054,0.416561,0.36327,00:35
1,1.703759,1.413061,0.626497,0.58775,00:35
2,1.572012,1.648814,0.534013,0.497901,00:35
3,1.493195,1.162035,0.740892,0.712919,00:35
4,1.431825,1.18904,0.727643,0.70223,00:35
5,1.377009,1.041309,0.792611,0.769856,00:35
6,1.340304,1.078979,0.767389,0.745437,00:35
7,1.274392,0.953758,0.831338,0.813277,00:35
8,1.255201,1.014942,0.802038,0.783343,00:35
9,1.243729,0.891011,0.858089,0.843205,00:35


# XResNeXt50S (no custom head), lr = 2e-3

In [None]:
# The run name carries the learning rate: the lr=4e-3 experiment of this same
# model is already named 'XResNeXt50S'. Presumably `train` uses run_name as the
# W&B run name, so identical names would make the two sweeps indistinguishable.
train(
    'XResNeXt50S 2e-3',
    model=xeca_resnext50s,
    lr=2e-3,
)

epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.933335,1.737519,0.484586,0.436084,00:35
1,1.719851,1.41021,0.644841,0.606267,00:35
2,1.587805,1.369821,0.649427,0.61571,00:35
3,1.508701,1.203386,0.732739,0.703493,00:35
4,1.471549,1.190647,0.725096,0.697941,00:34
5,1.429278,1.107611,0.768153,0.74369,00:35
6,1.361824,1.123043,0.753885,0.729566,00:35
7,1.338279,1.023283,0.795669,0.774048,00:35
8,1.285821,1.071675,0.77707,0.755276,00:36
9,1.250277,0.918175,0.842548,0.825346,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.985417,1.710314,0.499363,0.445541,00:35
1,1.749766,1.408817,0.64,0.601164,00:36
2,1.585358,1.360793,0.653503,0.619461,00:35
3,1.510056,1.219305,0.713121,0.682327,00:35
4,1.429974,1.194578,0.725605,0.697712,00:35
5,1.428774,1.107466,0.757197,0.731253,00:35
6,1.371912,1.443264,0.627261,0.611333,00:35
7,1.325946,1.034716,0.792866,0.770657,00:37
8,1.331943,1.106941,0.762803,0.740897,00:35
9,1.270881,0.974661,0.820892,0.802286,00:35


epoch,train_loss,valid_loss,accuracy,matthews_corrcoef,time
0,1.993358,1.733775,0.483822,0.431724,00:35
1,1.734066,1.440217,0.621147,0.580864,00:35
2,1.604012,1.495426,0.601529,0.565892,00:36
3,1.526832,1.2172,0.713631,0.682923,00:36
4,1.474202,1.238736,0.700637,0.671074,00:35
5,1.428157,1.112174,0.760255,0.734357,00:36
6,1.394197,1.106147,0.767134,0.742719,00:35
7,1.336922,1.024109,0.796178,0.774631,00:35
8,1.315803,1.073936,0.77656,0.753523,00:35
9,1.302023,0.973247,0.819618,0.799808,00:35
