# Logging

In [None]:
%env WANDB_SILENT=true
import wandb
wandb.login()

env: WANDB_SILENT=true


True

# Imports

In [None]:
from __future__ import annotations
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback
from fastxtend.vision.all import *
from tfrecord.reader import example_loader

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# AttentionPooling XResNet

In [None]:
def xeca_resnext50s(n_out=1000, **kwargs):
    "XResNeXt built from `ECAResNeXtBlock`: stages of [3, 4, 9] blocks with widths [64, 128, 256]"
    depths, widths = [3, 4, 9], [64, 128, 256]
    # 4 is passed as XResNet's second positional argument (presumably the block expansion)
    return XResNet(ECAResNeXtBlock, 4, depths, block_szs=widths, n_out=n_out, **kwargs)

# Petals

In [None]:
petals_stats = ([0.453,0.415,0.306], [0.282,0.245,0.272])

def get_petals_dls(size, bs, item_tfms=None, batch_tfms=None, workers=None, stats=True,
                   data_root='/home/benja/Data/Petals'):
    """Build train/valid `DataLoaders` for the Petals TFRecord dataset.

    size: target image size; sizes above 224 read the '512' export, otherwise the '224' export.
    bs: batch size.
    item_tfms: item transforms; defaults to `RandomResizedCrop(size, min_scale=0.35)`.
    batch_tfms: batch transforms (list or single transform). Never mutated: a copy is
        extended with `Normalize.from_stats(*petals_stats)` when `stats` is True.
    workers: number of DataLoader workers; defaults to `min(8, num_cpus())`.
    stats: whether to append the Petals normalization.
    data_root: directory holding the '224'/'512' exports, each with 'train' and 'val' folders.

    Returns `DataLoaders` whose validation set is every record read from the 'val' folder.
    """
    # Anything above 224 needs the larger export. The previous `>=256` / `<=224` pair left
    # `path` unbound (UnboundLocalError) for sizes 225-255.
    path = '512' if size > 224 else '224'

    data, n_train = [], 0
    for folder in ['train', 'val']:
        for fn in get_files(f'{data_root}/{path}/{folder}', extensions='.tfrec'):
            records = example_loader(str(fn), description={"image": "byte", "class": "int"}, index_path=None)
            data.extend([sample['image'], sample['class'][0]] for sample in records)
        # Train records are loaded first, so everything after them is the validation set.
        # Counting them replaces a hard-coded 12753 that would silently break if the split changed.
        if folder == 'train': n_train = len(data)

    def get_items(*args, **kwargs):
        return data

    get_x = lambda o: PILImage.create(io.BytesIO(o[0]))
    get_y = lambda o: o[1]
    splitter = IndexSplitter(range(n_train, len(data)))

    if workers is None: workers = min(8, num_cpus())
    # Work on a fresh list. The old `batch_tfms += [...]` appended `Normalize` to the caller's
    # list (and to the shared `[]` default) on every call.
    if batch_tfms is None: batch_tfms = []
    batch_tfms = list(batch_tfms) if is_listy(batch_tfms) else [batch_tfms]
    if stats: batch_tfms.append(Normalize.from_stats(*petals_stats))
    if item_tfms is None: item_tfms = RandomResizedCrop(size, min_scale=0.35)

    block = DataBlock(blocks=(ImageBlock, CategoryBlock),
                      get_items=get_items,
                      get_x=get_x,
                      get_y=get_y,
                      splitter=splitter,
                      item_tfms=item_tfms,
                      batch_tfms=batch_tfms)

    return block.dataloaders([''], bs=bs, num_workers=workers)

# Training Code

In [None]:
def train(run_name, model, lr, epochs=80, bs=64, size=224, loss_func=LabelSmoothingCrossEntropyFlat(), 
          batch_tfms=aug_transforms(max_zoom=1, flip_vert=True, max_rotate=45, xtra_tfms=[Hue(), Saturation()]), 
          seeds=[42,314,1618], valid_sizes=[256,384], log=True):
    """Train `model` on Petals once per seed, save each learner, and re-validate it at `valid_sizes`.

    run_name: prefix for the W&B run name and the saved-model file names.
    model: callable building the network; called with `n_out`, `act_cls`, `stem_pool`, `block_pool`.
    lr: learning rate for `fit_flat_cos` (ranger optimizer, fp16 training).
    epochs, bs, size: training epochs, batch size and training image size.
    loss_func, batch_tfms: loss and batch augmentations. Both defaults are created once, when
        this cell runs, and are shared by every call.
    seeds: one full train + validate cycle per seed.
    valid_sizes: int or list of image sizes for extra validation passes (batch size 48).
    log: if True, log to W&B and write each metric's mean/std across seeds to the run summary.
        If False, print the extra-validation scores instead.

    Re-raises `CancelFitException`; in that case no summary statistics are written.
    """
    if not is_listy(valid_sizes):
        valid_sizes = [valid_sizes]

    # Metric name -> one value per completed seed (final-epoch metrics plus extra-validation scores)
    run_results, broke = {}, False
    if log:
        run = wandb.init(project="AttentionPooling", name=f'{run_name} lr={lr}', group=f'Petals {epochs}E')
    try:
        for seed in seeds:
            # Reset per seed so the `finally` below never sees unbound names (first seed) or the
            # previous seed's learner, and so the old learner can be released early.
            learn = dls = None
            try:
                with less_random(seed):
                    dls = get_petals_dls(size, bs, batch_tfms=batch_tfms)

                cbs = [WandbCallback(log=None, log_preds=False, log_model=False)] if log else []
                with less_random(seed):
                    learn = Learner(dls, model(n_out=dls.c, act_cls=nn.Mish, stem_pool=MaxBlurPool, block_pool=BlurPool), 
                                    loss_func=loss_func, opt_func=ranger,
                                    metrics=[BalancedAccuracy(), MatthewsCorrCoef(), F1Score(average='macro')], 
                                    cbs=[CutMixUpAugment(), TerminateOnTrainNaN()]+cbs).to_fp16()

                with less_random(seed):
                    learn.fit_flat_cos(epochs, lr) 
                    learn.save(f'{run_name} lr={lr} seed={seed}', with_opt=False)

                if log:
                    # presumably the recorder's last row, aligned with `metric_names`
                    for n, v in zip(learn.recorder.metric_names, learn.recorder.log):
                        if n not in ['epoch', 'time']: run_results.setdefault(n, []).append(v)

                if log: learn.remove_cb(WandbCallback)
                for vs in valid_sizes:
                    dls = None
                    learn.dls = get_petals_dls(vs, 48)
                    loss, ba, matthews, f1 = learn.validate()
                    if log:
                        for n, v in zip([f'loss_{vs}', f'balanced_accuracy_{vs}', f'matthews_{vs}', f'f1score_{vs}'], [loss, ba, matthews, f1]):
                            run_results.setdefault(n, []).append(v)
                    else:
                        print(f'f1score_{vs}: {f1}, matthews_{vs}: {matthews}, balanced_accuracy_{vs}: {ba}')
            except CancelFitException:
                print('Terminate on NAN')
                broke = True
                raise
            finally:
                # `learn` is None when the failure happened before the Learner was built
                if learn is not None: free_gpu_memory(learn, dls)

    finally:
        if log:
            # `run_results` may be empty (first seed failed early); the loop is then a no-op
            if not broke:
                for n, vals in run_results.items():
                    run.summary[f'{n}_mean'] = np.mean(vals)
                    run.summary[f'{n}_std'] = np.std(vals)
            run.finish()

# XResNeXt50

In [None]:
# Baseline run with `xeca_resnext50`, using `train`'s default recipe once per default seed.
# NOTE(review): `xeca_resnext50` is not defined in this notebook; presumably it comes from the
# fastxtend star import. Confirm.
train('XResNeXt50', xeca_resnext50, lr=8e-3)

epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.8566,3.512599,0.080187,0.180024,0.070097,00:42
1,3.602139,3.21703,0.120199,0.239301,0.107326,00:38
2,3.34809,2.900763,0.219084,0.325703,0.209999,00:39
3,3.160113,2.63855,0.32078,0.416993,0.316223,00:38
4,2.921081,2.503701,0.393555,0.474877,0.398531,00:39
5,2.765812,2.086645,0.506675,0.599919,0.517172,00:38
6,2.668262,2.005182,0.552763,0.639725,0.558163,00:39
7,2.562153,1.863424,0.617738,0.681586,0.631275,00:38
8,2.409245,1.820868,0.672775,0.699983,0.668074,00:39
9,2.365008,1.712467,0.695297,0.73925,0.711368,00:38


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.825866,3.524153,0.105167,0.178139,0.080884,00:38
1,3.610751,3.225566,0.126238,0.232832,0.116152,00:38
2,3.434928,3.00141,0.19138,0.304691,0.179499,00:39
3,3.214,2.777242,0.264963,0.374399,0.249824,00:39
4,2.982632,2.509317,0.37884,0.45634,0.362172,00:38
5,2.772416,2.084107,0.492886,0.599698,0.495893,00:38
6,2.631877,1.957457,0.570337,0.652927,0.582466,00:38
7,2.474357,1.84299,0.608265,0.692485,0.62925,00:39
8,2.431706,1.815778,0.632535,0.702546,0.654213,00:39
9,2.340867,1.75065,0.697399,0.735557,0.698868,00:39


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.81015,3.399865,0.085222,0.191727,0.07421,00:38
1,3.59465,3.203567,0.165313,0.250792,0.134376,00:38
2,3.374777,3.17485,0.162214,0.258071,0.136454,00:38
3,3.085263,2.685839,0.307469,0.415685,0.293421,00:39
4,2.93106,2.440459,0.411722,0.489225,0.407765,00:38
5,2.708336,2.058211,0.517862,0.614685,0.513513,00:39
6,2.607032,1.907588,0.605726,0.675034,0.612585,00:38
7,2.515226,1.869227,0.61621,0.688417,0.631403,00:38
8,2.369284,1.724864,0.678232,0.737018,0.701371,00:39
9,2.321344,1.731684,0.689709,0.730617,0.703869,00:38


# XResNeXt50S

In [None]:
# Same recipe with the three-stage `xeca_resnext50s` defined earlier in this notebook.
train('XResNeXt50S', xeca_resnext50s, lr=8e-3)

epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.7494,3.385743,0.085055,0.194243,0.074018,00:37
1,3.509467,3.124067,0.151909,0.27399,0.144027,00:37
2,3.290705,2.858089,0.220578,0.337046,0.215158,00:38
3,3.063619,2.595628,0.32977,0.436382,0.32833,00:37
4,2.820647,2.378977,0.417245,0.516086,0.43514,00:37
5,2.680211,2.014882,0.542583,0.630336,0.557888,00:37
6,2.568132,1.958049,0.582177,0.658439,0.592892,00:37
7,2.469881,1.796526,0.657379,0.714067,0.673701,00:37
8,2.319496,1.753216,0.690368,0.732329,0.700253,00:38
9,2.276789,1.680592,0.704108,0.75481,0.722485,00:37


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.735625,3.381425,0.113598,0.207237,0.094709,00:37
1,3.510208,3.41685,0.146181,0.250784,0.146362,00:37
2,3.282134,2.865857,0.246657,0.357401,0.249151,00:37
3,3.081913,2.630831,0.323939,0.436208,0.332236,00:37
4,2.845621,2.410048,0.419571,0.49198,0.410478,00:37
5,2.644345,1.962214,0.562358,0.650468,0.570045,00:37
6,2.514336,1.883592,0.602377,0.687954,0.623549,00:37
7,2.388488,1.775987,0.649724,0.713042,0.673095,00:38
8,2.36179,1.754533,0.631185,0.714517,0.65818,00:38
9,2.265226,1.725445,0.695722,0.738647,0.707863,00:38


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.717364,3.462643,0.088278,0.184138,0.078848,00:37
1,3.523117,3.083325,0.206404,0.298242,0.190166,00:37
2,3.256004,2.874787,0.235252,0.35075,0.213148,00:37
3,2.962602,2.519925,0.35377,0.462455,0.34741,00:38
4,2.833249,2.364273,0.411959,0.507357,0.416552,00:37
5,2.619587,1.992729,0.570329,0.646455,0.582587,00:38
6,2.537615,1.85104,0.60294,0.692289,0.624668,00:37
7,2.446482,1.763961,0.657663,0.724919,0.692122,00:37
8,2.304383,1.69019,0.706019,0.740614,0.717159,00:38
9,2.260761,1.706365,0.70685,0.744978,0.724533,00:37


# Learned Aggregation

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm normalizes channels in dim 1, so it runs before the (B, N, C) transpose.
        # Other norms (e.g. LayerNorm) normalize the last dim, so they run after it.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "Batch-normalize `x` and return it as (B, H*W, C) tokens"
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4D input, so normalize before flattening.
            # Flattening first (as before) always raised for BatchNorm2d.
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` to (B, H*W, C) tokens, then normalize over the channel (last) dim"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool feature map `x` (presumably (B, C, H, W) with C == ni) into a (B, C) tensor
        by single-query attention with the learned query `cls_q` (presumably (1, C))."""
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))                 # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).unbind(2)     # each (B, N, C)

        # NOTE(review): logits are not scaled by C**-0.5 as in standard scaled dot-product
        # attention. Left unchanged to keep results comparable; confirm this is intended.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)    # (B, 1, N)

        x = (attn @ v).reshape(B, C)
        return self.proj(x)
 
    
class LearnedAggregation(nn.Module):
    "Learned Aggregation from https://arxiv.org/abs/2112.13692"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        hidden = int(ni*ffn_expand)
        # Per-channel residual scales, initialized small so both residual branches start weak
        self.gamma_1 = nn.Parameter(torch.full((ni,), 1e-4))
        self.gamma_2 = nn.Parameter(torch.full((ni,), 1e-4))
        # Learned query that attends over the feature map; also the residual base in `forward`
        self.cls_q = nn.Parameter(torch.zeros(1, ni))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.norm = norm(ni)
        self.ffn = nn.Sequential(
            nn.Linear(ni, hidden),
            act_cls(),
            nn.Linear(hidden, ni),
        )
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        "Attention-pool `x` into `cls_q` (residual), then apply a pre-norm FFN (residual)"
        pooled = self.attn(x, self.cls_q)
        tokens = self.cls_q + self.gamma_1 * pooled
        return tokens + self.gamma_2 * self.ffn(self.norm(tokens))

    @torch.no_grad()
    def _init_weights(self, m):
        "Truncated-normal (std 0.02) weight and zero bias for `nn.Linear`; other modules untouched"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
def LearnedAggregationHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Head layers: `LearnedAggregation` pooling, a final `norm`, and a linear classifier"
    pool = LearnedAggregation(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    head = [pool, norm(ni), nn.Linear(ni, n_out)]
    with torch.no_grad():
        # `_init_weights` only re-initializes `nn.Linear`, so in practice only the classifier changes
        for layer in head[1:]:
            pool._init_weights(layer)
    return head

In [None]:
# XResNeXt50S with the learned-aggregation head from the cells just above.
# NOTE(review): later sections redefine `AttentionPool2d`/`LearnedAggregation`/`LearnedAggregationHead`,
# so this cell only reproduces these results when the notebook is run top to bottom.
train('XResNeXt50S LearnedAggregation', partial(xeca_resnext50s, custom_head=LearnedAggregationHead), lr=6e-4)

epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.942313,3.592682,0.042077,0.142251,0.026549,00:43
1,3.732117,3.322711,0.087226,0.219985,0.072196,00:39
2,3.521841,3.052122,0.147268,0.286151,0.130992,00:40
3,3.401019,2.928377,0.212953,0.348127,0.197863,00:39
4,3.260453,2.819321,0.244789,0.359461,0.225787,00:40
5,3.144395,2.570801,0.296524,0.434015,0.289368,00:39
6,3.068087,2.474399,0.349186,0.473645,0.349384,00:40
7,2.977717,2.345891,0.405898,0.522798,0.40556,00:39
8,2.857626,2.297814,0.433149,0.516927,0.428497,00:39
9,2.816922,2.20335,0.463086,0.562239,0.47406,00:39


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.856097,3.437283,0.080388,0.194379,0.056389,00:40
1,3.591661,3.134887,0.142923,0.274214,0.126617,00:40
2,3.442032,3.049665,0.169115,0.289987,0.154417,00:40
3,3.34175,2.822592,0.237758,0.361402,0.221535,00:40
4,3.228452,2.721284,0.303371,0.39958,0.295359,00:40
5,3.087149,2.493648,0.35507,0.465309,0.355532,00:40
6,3.004003,2.449016,0.355154,0.479693,0.355504,00:40
7,2.916291,2.34257,0.405862,0.51344,0.418482,00:40
8,2.895232,2.282868,0.431245,0.5236,0.433221,00:40
9,2.808582,2.260231,0.47881,0.536646,0.476733,00:40


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.974149,3.618375,0.042105,0.136252,0.027099,00:40
1,3.818916,3.423566,0.08086,0.189915,0.061968,00:40
2,3.644516,3.205244,0.12791,0.247141,0.101813,00:40
3,3.509454,3.09827,0.151274,0.28077,0.137177,00:41
4,3.40749,2.950075,0.230304,0.322591,0.193595,00:40
5,3.269278,2.705898,0.276557,0.393536,0.25859,00:40
6,3.1857,2.605543,0.32309,0.43763,0.309491,00:40
7,3.104107,2.465721,0.353705,0.468659,0.361518,00:40
8,2.968607,2.408585,0.36804,0.487608,0.369997,00:40
9,2.947391,2.305183,0.423172,0.528993,0.421511,00:40


# Learned Aggregation Sandwich

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm normalizes channels in dim 1, so it runs before the (B, N, C) transpose.
        # Other norms (e.g. LayerNorm) normalize the last dim, so they run after it.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "Batch-normalize `x` and return it as (B, H*W, C) tokens"
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4D input, so normalize before flattening.
            # Flattening first (as before) always raised for BatchNorm2d.
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` to (B, H*W, C) tokens, then normalize over the channel (last) dim"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool feature map `x` (presumably (B, C, H, W) with C == ni) into a (B, C) tensor
        by single-query attention with the learned query `cls_q` (presumably (1, C))."""
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))                 # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).unbind(2)     # each (B, N, C)

        # NOTE(review): logits are not scaled by C**-0.5 as in standard scaled dot-product
        # attention. Left unchanged to keep results comparable; confirm this is intended.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)    # (B, 1, N)

        x = (attn @ v).reshape(B, C)
        return self.proj(x)
 
    
class LearnedAggregation(nn.Module):
    "Learned Aggregation from https://arxiv.org/abs/2112.13692"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        hidden = int(ni*ffn_expand)
        # Per-channel residual scales, initialized small so both residual branches start weak
        self.gamma_1 = nn.Parameter(torch.full((ni,), 1e-4))
        self.gamma_2 = nn.Parameter(torch.full((ni,), 1e-4))
        # Learned query that attends over the feature map; also the residual base in `forward`
        self.cls_q = nn.Parameter(torch.zeros(1, ni))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        # "Sandwich" variant: norm1 after attention, norm2 before the FFN, plus a norm inside the FFN
        self.norm1 = norm(ni)
        self.norm2 = norm(ni)
        self.ffn = nn.Sequential(
            nn.Linear(ni, hidden),
            act_cls(),
            norm(hidden),
            nn.Linear(hidden, ni),
        )
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        "Add the normalized attention pool of `x` to `cls_q`, then apply a pre-norm FFN residual"
        pooled = self.norm1(self.attn(x, self.cls_q))
        tokens = self.cls_q + self.gamma_1 * pooled
        return tokens + self.gamma_2 * self.ffn(self.norm2(tokens))

    @torch.no_grad()
    def _init_weights(self, m):
        "Truncated-normal (std 0.02) weight and zero bias for `nn.Linear`; other modules untouched"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
def LearnedAggregationHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Head layers: `LearnedAggregation` pooling, a final `norm`, and a linear classifier"
    pool = LearnedAggregation(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    head = [pool, norm(ni), nn.Linear(ni, n_out)]
    with torch.no_grad():
        # `_init_weights` only re-initializes `nn.Linear`, so in practice only the classifier changes
        for layer in head[1:]:
            pool._init_weights(layer)
    return head

In [None]:
# Sandwich-norm variant of the learned-aggregation head defined in the cells just above.
# NOTE(review): these class/function names are redefined in several sections; run cells in order.
train('XResNeXt50S LearnedAggregation Sandwich', partial(xeca_resnext50s, custom_head=LearnedAggregationHead), lr=1e-3)

epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.721831,3.206274,0.10974,0.242272,0.093319,00:43
1,3.464206,2.949641,0.192646,0.32392,0.17806,00:39
2,3.321413,2.784479,0.246272,0.382313,0.238564,00:40
3,3.179252,2.685165,0.310478,0.409213,0.295287,00:39
4,3.048018,2.51644,0.354663,0.46075,0.342694,00:40
5,2.926201,2.454544,0.359786,0.479149,0.359102,00:39
6,2.861755,2.384737,0.391944,0.498127,0.385338,00:40
7,2.775338,2.118307,0.494504,0.590268,0.503494,00:39
8,2.654123,2.014692,0.537062,0.624492,0.537462,00:40
9,2.607427,1.960675,0.54345,0.637889,0.564053,00:39


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.722145,3.23109,0.116842,0.235692,0.094406,00:40
1,3.532604,2.999215,0.181482,0.314402,0.173372,00:40
2,3.350199,2.825323,0.248615,0.371363,0.233697,00:40
3,3.232605,2.674874,0.293835,0.416043,0.284692,00:40
4,3.10767,2.529971,0.358653,0.455817,0.342723,00:40
5,2.963178,2.435473,0.363745,0.482778,0.359603,00:40
6,2.858174,2.394844,0.386242,0.510391,0.40651,00:40
7,2.759888,2.142597,0.474853,0.58319,0.493213,00:40
8,2.728804,2.048705,0.508297,0.611868,0.521129,00:40
9,2.647335,2.072525,0.529473,0.609094,0.541538,00:40


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.818058,3.325754,0.086618,0.205294,0.07071,00:40
1,3.566523,3.064178,0.18093,0.294713,0.163282,00:40
2,3.376781,2.84243,0.246323,0.363262,0.218115,00:40
3,3.229154,2.811182,0.254306,0.367037,0.244249,00:41
4,3.134282,2.622929,0.314724,0.426566,0.296995,00:40
5,2.979655,2.491088,0.351472,0.465526,0.351674,00:40
6,2.911321,2.287581,0.42532,0.533588,0.43156,00:40
7,2.831218,2.140589,0.476795,0.572007,0.485233,00:40
8,2.691633,2.06234,0.500676,0.603852,0.515337,00:40
9,2.669273,2.023359,0.524915,0.610904,0.536077,00:40


# AvgAttnPooling2d Sandwich

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm normalizes channels in dim 1, so it runs before the (B, N, C) transpose.
        # Other norms (e.g. LayerNorm) normalize the last dim, so they run after it.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "Batch-normalize `x` and return it as (B, H*W, C) tokens"
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4D input, so normalize before flattening.
            # Flattening first (as before) always raised for BatchNorm2d.
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` to (B, H*W, C) tokens, then normalize over the channel (last) dim"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool feature map `x` (presumably (B, C, H, W) with C == ni) into a (B, C) tensor
        by single-query attention with the learned query `cls_q` (presumably (1, C))."""
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))                 # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).unbind(2)     # each (B, N, C)

        # NOTE(review): logits are not scaled by C**-0.5 as in standard scaled dot-product
        # attention. Left unchanged to keep results comparable; confirm this is intended.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)    # (B, 1, N)

        x = (attn @ v).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnPooling2d(nn.Module):
    "Pool a feature map by summing average pooling and attention pooling, then add an FFN residual"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        hidden = int(ni*ffn_expand)
        # Learned query for the attention pool
        self.cls_q = nn.Parameter(torch.zeros(1, ni))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm = norm(ni)
        self.ffn = nn.Sequential(
            nn.Linear(ni, hidden),
            act_cls(),
            norm(hidden),
            nn.Linear(hidden, ni),
        )
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        "Normalize avg-pool + attention-pool of `x`, then add the FFN of that result"
        avg = self.pool(x).flatten(1)
        fused = self.norm(avg + self.attn(x, self.cls_q))
        return fused + self.ffn(fused)

    @torch.no_grad()
    def _init_weights(self, m):
        "Truncated-normal (std 0.02) weight and zero bias for `nn.Linear`; other modules untouched"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
def AvgAttnPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Head layers: `AvgAttnPooling2d`, a final `norm`, and a linear classifier"
    pool = AvgAttnPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    head = [pool, norm(ni), nn.Linear(ni, n_out)]
    with torch.no_grad():
        # `_init_weights` only re-initializes `nn.Linear`, so in practice only the classifier changes
        for layer in head[1:]:
            pool._init_weights(layer)
    return head

In [None]:
# Average + attention pooling head with a normalized FFN, defined in the cells just above.
# NOTE(review): `AvgAttnPooling2d`/`AvgAttnPoolHead` are redefined in later sections; run cells in order.
train('XResNeXt50S AvgAttnPool Sandwich', partial(xeca_resnext50s, custom_head=AvgAttnPoolHead), lr=2e-3)

epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.760755,3.252065,0.12775,0.247936,0.114464,00:44
1,3.551856,3.032653,0.162035,0.287298,0.152099,00:40
2,3.445181,2.943063,0.190224,0.317008,0.173,00:41
3,3.321381,2.869853,0.232946,0.357193,0.215592,00:40
4,3.206003,2.782686,0.272751,0.384319,0.259484,00:41
5,3.09884,2.751137,0.263933,0.395998,0.235948,00:40
6,2.995364,2.357382,0.394024,0.512329,0.381527,00:41
7,2.904516,2.22121,0.429372,0.558831,0.436485,00:40
8,2.7664,2.112513,0.493965,0.590655,0.484846,00:41
9,2.724129,2.1304,0.469336,0.580526,0.485744,00:41


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.783049,3.354963,0.109987,0.207081,0.092442,00:41
1,3.608206,3.11069,0.164026,0.277789,0.151793,00:41
2,3.455051,3.036899,0.179081,0.301405,0.163245,00:41
3,3.357688,2.870399,0.204659,0.339305,0.192165,00:41
4,3.250488,2.71332,0.268815,0.388337,0.249501,00:41
5,3.103336,2.557348,0.296149,0.438671,0.281478,00:41
6,2.999174,2.401103,0.342483,0.488177,0.339455,00:41
7,2.907992,2.295131,0.405737,0.522385,0.41059,00:41
8,2.882618,2.198345,0.419223,0.553957,0.420801,00:41
9,2.788747,2.223221,0.455595,0.556003,0.4594,00:41


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.797479,3.290395,0.129235,0.223914,0.109852,00:41
1,3.627413,3.144405,0.176662,0.275519,0.137784,00:41
2,3.469552,3.013371,0.187556,0.310749,0.162744,00:41
3,3.310723,2.938065,0.232593,0.343449,0.213211,00:41
4,3.259302,2.836236,0.24828,0.34885,0.225597,00:41
5,3.111783,2.745301,0.254419,0.376662,0.242319,00:41
6,3.039078,2.403023,0.359117,0.491821,0.333122,00:41
7,2.957411,2.259984,0.411718,0.53535,0.414067,00:41
8,2.825278,2.214396,0.43792,0.558146,0.443798,00:41
9,2.792737,2.10642,0.486123,0.586951,0.488519,00:41


# AvgAttnPooling2d Sandwich Act

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm normalizes channels in dim 1, so it runs before the (B, N, C) transpose.
        # Other norms (e.g. LayerNorm) normalize the last dim, so they run after it.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "Batch-normalize `x` and return it as (B, H*W, C) tokens"
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4D input, so normalize before flattening.
            # Flattening first (as before) always raised for BatchNorm2d.
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` to (B, H*W, C) tokens, then normalize over the channel (last) dim"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool feature map `x` (presumably (B, C, H, W) with C == ni) into a (B, C) tensor
        by single-query attention with the learned query `cls_q` (presumably (1, C))."""
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))                 # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).unbind(2)     # each (B, N, C)

        # NOTE(review): logits are not scaled by C**-0.5 as in standard scaled dot-product
        # attention. Left unchanged to keep results comparable; confirm this is intended.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)    # (B, 1, N)

        x = (attn @ v).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnPooling2d(nn.Module):
    "Pool a feature map by summing average pooling and attention pooling, then add an FFN residual"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        hidden = int(ni*ffn_expand)
        # Learned query for the attention pool
        self.cls_q = nn.Parameter(torch.zeros(1, ni))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm = norm(ni)
        self.ffn = nn.Sequential(
            nn.Linear(ni, hidden),
            act_cls(),
            norm(hidden),
            nn.Linear(hidden, ni),
        )
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        "Normalize avg-pool + attention-pool of `x`, then add the FFN of that result"
        avg = self.pool(x).flatten(1)
        fused = self.norm(avg + self.attn(x, self.cls_q))
        return fused + self.ffn(fused)

    @torch.no_grad()
    def _init_weights(self, m):
        "Truncated-normal (std 0.02) weight and zero bias for `nn.Linear`; other modules untouched"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
def AvgAttnPoolHead(ni, n_out, norm=nn.LayerNorm, act_cls=nn.GELU, ffn_expand=3, **kwargs):
    """Head layers: `AvgAttnPooling2d`, a final `norm`, `act_cls()`, then a linear classifier.

    Returns the layers as a list; the classifier (last entry) gets the pooling block's
    trunc-normal weight / zero-bias init.
    """
    pool = AvgAttnPooling2d(ni, norm=norm, ffn_expand=ffn_expand, act_cls=act_cls, **kwargs)
    head = [pool, norm(ni), act_cls(), nn.Linear(ni, n_out)]
    with torch.no_grad():
        # Init every layer after the pool; `_init_weights` only touches `nn.Linear`. The previous
        # code called it on head[1] and head[2]. After `act_cls()` was inserted, head[2] is the
        # activation, so the classifier kept PyTorch's default init.
        for layer in head[1:]:
            pool._init_weights(layer)
    return head

In [None]:
# Same average + attention pooling head, with an activation between the final norm and the classifier.
# NOTE(review): `AvgAttnPooling2d`/`AvgAttnPoolHead` are redefined in other sections; run cells in order.
train('XResNeXt50S AvgAttnPool Sandwich Act', partial(xeca_resnext50s, custom_head=AvgAttnPoolHead), lr=3e-3)

epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.801355,3.367344,0.088959,0.206338,0.066809,00:44
1,3.634204,3.243986,0.097912,0.223982,0.078516,00:40
2,3.510776,3.170869,0.12365,0.249913,0.11394,00:41
3,3.415717,3.036126,0.182118,0.304849,0.156121,00:40
4,3.27703,2.905838,0.238797,0.347042,0.202473,00:41
5,3.152895,2.973745,0.189017,0.318272,0.167384,00:41
6,3.070151,2.451154,0.330643,0.47403,0.320171,00:41
7,2.971345,2.254404,0.419086,0.536695,0.415726,00:40
8,2.855281,2.259104,0.425111,0.529185,0.406666,00:41
9,2.817854,2.238568,0.413325,0.5397,0.415149,00:41


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.817756,3.424497,0.079523,0.164901,0.053603,00:41
1,3.668667,3.338164,0.10938,0.21763,0.09658,00:41
2,3.518888,3.172177,0.13588,0.252874,0.122349,00:41
3,3.445122,3.07779,0.175057,0.29045,0.147171,00:41
4,3.325016,2.870094,0.23102,0.331427,0.201474,00:41
5,3.180367,2.757786,0.255971,0.385736,0.23912,00:41
6,3.050837,2.463984,0.331523,0.463651,0.330703,00:41
7,2.983763,2.344941,0.379181,0.50579,0.384923,00:41
8,2.969962,2.300706,0.391729,0.517289,0.378785,00:41
9,2.869169,2.253119,0.450659,0.549921,0.441141,00:41


wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectTimeout), entering retry loop.


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.832584,3.346924,0.089421,0.196859,0.073593,00:47
1,3.652154,3.268334,0.141004,0.23754,0.108406,00:47
2,3.513377,3.073962,0.152952,0.280028,0.124869,00:46
3,3.431386,3.01211,0.190357,0.314349,0.168748,00:47
4,3.360487,2.964052,0.217401,0.317936,0.184563,00:47
5,3.219356,2.813107,0.221845,0.351812,0.203642,00:47
6,3.156742,2.513036,0.343037,0.458063,0.329531,00:47
7,3.072562,2.397222,0.354562,0.481183,0.350624,00:45
8,2.971532,2.361155,0.394428,0.50337,0.385113,00:46
9,2.924478,2.32838,0.395015,0.498838,0.393614,00:47


# AvgAttnPooling2d SansFFN Act

In [None]:
class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        # BatchNorm normalizes channels in dim 1, so it runs before the (B, N, C) transpose.
        # Other norms (e.g. LayerNorm) normalize the last dim, so they run after it.
        if isinstance(self.norm, (nn.BatchNorm1d, nn.BatchNorm2d)):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn_norm_flat(self, x:Tensor):
        "Batch-normalize `x` and return it as (B, H*W, C) tokens"
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4D input, so normalize before flattening.
            # Flattening first (as before) always raised for BatchNorm2d.
            return self.norm(x).flatten(2).transpose(1,2)
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten `x` to (B, H*W, C) tokens, then normalize over the channel (last) dim"
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Pool feature map `x` (presumably (B, C, H, W) with C == ni) into a (B, C) tensor
        by single-query attention with the learned query `cls_q` (presumably (1, C))."""
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))                 # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).unbind(2)     # each (B, N, C)

        # NOTE(review): logits are not scaled by C**-0.5 as in standard scaled dot-product
        # attention. Left unchanged to keep results comparable; confirm this is intended.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)    # (B, 1, N)

        x = (attn @ v).reshape(B, C)
        return self.proj(x)
 
    
class AvgAttnPooling2d(nn.Module):
    "Pool a feature map as act(norm(avg_pool + attention_pool)), with no FFN"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        # `ffn_expand` is accepted to keep the signature of the FFN variants but is unused here
        self.cls_q = nn.Parameter(torch.zeros(1, ni))  # learned query for the attention pool
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm = norm(ni)
        self.act = act_cls()
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        "Sum the average pool and attention pool of `x`, then normalize and activate"
        summed = self.pool(x).flatten(1) + self.attn(x, self.cls_q)
        return self.act(self.norm(summed))

    @torch.no_grad()
    def _init_weights(self, m):
        "Truncated-normal (std 0.02) weight and zero bias for `nn.Linear`; other modules untouched"
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
def AvgAttnPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    "Head layers: `AvgAttnPooling2d` followed directly by a linear classifier"
    pool = AvgAttnPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    classifier = nn.Linear(ni, n_out)
    with torch.no_grad():
        pool._init_weights(classifier)  # trunc-normal weight, zero bias
    return [pool, classifier]

In [None]:
# FFN-free average + attention pooling head (norm + activation only), defined in the cells just above.
# NOTE(review): `AvgAttnPooling2d`/`AvgAttnPoolHead` are redefined in other sections; run cells in order.
train('XResNeXt50S AvgAttnPool SansFFN Act', partial(xeca_resnext50s, custom_head=AvgAttnPoolHead), lr=3e-3)

epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.727144,3.224067,0.128442,0.235333,0.11237,00:46
1,3.50639,2.976424,0.177693,0.318005,0.171378,00:45
2,3.346684,2.83365,0.226195,0.348046,0.221429,00:47
3,3.203139,2.667416,0.308606,0.400507,0.286896,00:47
4,3.01905,2.480593,0.36286,0.467323,0.358315,00:47
5,2.917349,2.452967,0.346727,0.472527,0.336422,00:47
6,2.79472,2.360375,0.405341,0.498137,0.402519,00:45
7,2.715605,1.991845,0.537433,0.629827,0.541718,00:46
8,2.561709,1.902613,0.593279,0.667205,0.593038,00:48
9,2.539161,1.844825,0.591841,0.682022,0.603762,00:47


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.724814,3.263175,0.122637,0.216532,0.099032,00:48
1,3.541528,3.109088,0.192549,0.288569,0.17567,00:48
2,3.409772,3.036436,0.193417,0.27038,0.190094,00:47
3,3.325612,2.938294,0.221294,0.324032,0.201672,00:48
4,3.179899,2.620766,0.308284,0.399346,0.297975,00:44
5,3.019681,2.450175,0.371513,0.482148,0.362113,00:47
6,2.91545,2.303976,0.384342,0.517211,0.395164,00:47
7,2.809995,2.137025,0.486886,0.585877,0.49667,00:48
8,2.766986,2.10009,0.478612,0.577381,0.482919,00:48
9,2.674103,2.074487,0.532796,0.604285,0.524231,00:48


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.705876,3.176106,0.154499,0.254753,0.132199,00:47
1,3.545747,3.1276,0.189187,0.286401,0.156083,00:47
2,3.385763,2.908806,0.213131,0.329145,0.196749,00:47
3,3.205624,2.808205,0.279464,0.370884,0.268626,00:47
4,3.138513,2.763533,0.27702,0.385278,0.257131,00:48
5,2.951607,2.695894,0.296,0.393045,0.277574,00:47
6,2.87036,2.193095,0.450686,0.567205,0.455935,00:48
7,2.789492,2.028685,0.513711,0.616967,0.530873,00:48
8,2.662072,2.003366,0.519385,0.624391,0.523754,00:47
9,2.634128,2.019936,0.514601,0.609876,0.52712,00:46


# AvgAttnConcatPooling2d SansFFN

In [None]:
class AttentionPool2d(nn.Module):
    """Attention for Learned Aggregation: pool a feature map with one learned query.

    The channels-first feature map `x` (e.g. `(B, C, H, W)`) is flattened into `N`
    tokens of width `C`, normalized, projected to keys/values, and attended over
    by a single query derived from `cls_q`. The result is a `(B, C)` pooled vector.

    Args:
        ni: Channel count `C`; also the width of the q/k/v/output projections.
        bias: Whether the query and key/value projections carry a bias.
        norm: Factory called as `norm(ni)` for the token normalization.
            `nn.BatchNorm2d` is applied to the 4-D map before flattening,
            `nn.BatchNorm1d` to the flattened `(B, C, N)` map, and any other
            norm (e.g. `nn.LayerNorm`) to the `(B, N, C)` token sequence.
    """
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4-D input, so it must run before flattening.
            self.norm_forward = self.bn2d_norm_flat
        elif isinstance(self.norm, nn.BatchNorm1d):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn2d_norm_flat(self, x:Tensor):
        "Normalize the 4-D map with BatchNorm2d, then flatten to `(B, N, C)` tokens."
        return self.norm(x).flatten(2).transpose(1,2)

    def bn_norm_flat(self, x:Tensor):
        "Flatten to `(B, C, N)`, normalize per channel (BatchNorm1d), then transpose to `(B, N, C)`."
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten to `(B, N, C)` tokens, then normalize the last dim (e.g. LayerNorm)."
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Attend over the spatial tokens of `x` with the learned query `cls_q`.

        `cls_q` is expected to be `(1, C)`; it is broadcast to `(B, 1, C)`.
        Returns a tensor of shape `(B, C)`.
        """
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))                                    # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).permute(2, 0, 1, 3).unbind(0)   # each (B, N, C)

        # NOTE(review): logits are not scaled by 1/sqrt(C) as in standard
        # scaled dot-product attention; kept unscaled to match the original runs.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)                       # (B, 1, N)

        x = (attn @ v).reshape(B, C)                                           # (B, 1, C) -> (B, C)
        return self.proj(x)
 
    
class AvgAttnConcatPooling2d(nn.Module):
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.norm1 = norm(ni)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm2 = norm(ni)
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        return torch.cat([self.norm2(self.pool(x).flatten(1)), self.norm1(self.attn(x, self.cls_q))], dim=1)

    @torch.no_grad()
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

In [None]:
def AvgAttnConcatPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    """Build the classifier head `[AvgAttnConcatPooling2d, nn.Linear(2*ni, n_out)]`.

    The linear layer takes `2*ni` inputs because the pooling layer concatenates
    its average-pooled and attention-pooled features. The linear layer is
    re-initialized with the pooling layer's `_init_weights`.
    """
    pooling = AvgAttnConcatPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    classifier = nn.Linear(2*ni, n_out)
    with torch.no_grad():
        pooling._init_weights(classifier)
    return [pooling, classifier]

In [None]:
train(
    run_name='XResNeXt50S AvgAttnConcatPool SansFFN',
    model=partial(xeca_resnext50s, custom_head=AvgAttnConcatPoolHead),
    lr=1e-3,
)

epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.707652,3.13109,0.186142,0.281935,0.163482,00:52
1,3.461668,2.851134,0.255681,0.369298,0.254748,00:47
2,3.303868,2.777519,0.294392,0.386213,0.297582,00:46
3,3.190584,2.62563,0.359521,0.443665,0.34294,00:47
4,3.05895,2.570595,0.37866,0.44595,0.3798,00:47
5,2.94717,2.445122,0.390674,0.493729,0.390014,00:47
6,2.887095,2.237438,0.472251,0.563436,0.479638,00:47
7,2.78794,2.134237,0.521544,0.589007,0.521263,00:47
8,2.698291,2.115023,0.54123,0.600311,0.53452,00:46
9,2.656405,2.047428,0.566205,0.630445,0.571932,00:47


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.705544,3.128618,0.182738,0.287074,0.168324,00:48
1,3.489909,2.939498,0.256528,0.344917,0.254397,00:47
2,3.305806,2.783493,0.313164,0.391302,0.297111,00:48
3,3.180682,2.697955,0.346254,0.42158,0.343265,00:48
4,3.059373,2.522452,0.393646,0.471782,0.382245,00:48
5,2.916565,2.319428,0.453208,0.527713,0.451762,00:48
6,2.800817,2.172289,0.498576,0.583902,0.509833,00:45
7,2.732074,2.092673,0.533445,0.606926,0.548002,00:47
8,2.704672,2.022696,0.540977,0.624849,0.555428,00:48
9,2.63577,2.029439,0.610475,0.643866,0.612032,00:48


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.712416,3.13028,0.190356,0.276907,0.184139,00:48
1,3.467668,3.003263,0.234793,0.320466,0.222597,00:48
2,3.272699,2.744234,0.319837,0.390896,0.29386,00:48
3,3.095148,2.585558,0.357544,0.446114,0.362968,00:45
4,3.028565,2.473835,0.425327,0.484519,0.403661,00:47
5,2.867669,2.397583,0.42425,0.499428,0.433673,00:48
6,2.80164,2.186249,0.49232,0.572777,0.498556,00:48
7,2.749241,2.073919,0.546803,0.609688,0.563635,00:48
8,2.629795,1.983298,0.57854,0.635103,0.588683,00:48
9,2.611807,1.973777,0.601982,0.640666,0.607575,00:47


# AvgAttnConcatPooling2d Sandwich

In [None]:
class AttentionPool2d(nn.Module):
    """Attention for Learned Aggregation: pool a feature map with one learned query.

    The channels-first feature map `x` (e.g. `(B, C, H, W)`) is flattened into `N`
    tokens of width `C`, normalized, projected to keys/values, and attended over
    by a single query derived from `cls_q`. The result is a `(B, C)` pooled vector.

    Args:
        ni: Channel count `C`; also the width of the q/k/v/output projections.
        bias: Whether the query and key/value projections carry a bias.
        norm: Factory called as `norm(ni)` for the token normalization.
            `nn.BatchNorm2d` is applied to the 4-D map before flattening,
            `nn.BatchNorm1d` to the flattened `(B, C, N)` map, and any other
            norm (e.g. `nn.LayerNorm`) to the `(B, N, C)` token sequence.
    """
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4-D input, so it must run before flattening.
            self.norm_forward = self.bn2d_norm_flat
        elif isinstance(self.norm, nn.BatchNorm1d):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn2d_norm_flat(self, x:Tensor):
        "Normalize the 4-D map with BatchNorm2d, then flatten to `(B, N, C)` tokens."
        return self.norm(x).flatten(2).transpose(1,2)

    def bn_norm_flat(self, x:Tensor):
        "Flatten to `(B, C, N)`, normalize per channel (BatchNorm1d), then transpose to `(B, N, C)`."
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten to `(B, N, C)` tokens, then normalize the last dim (e.g. LayerNorm)."
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Attend over the spatial tokens of `x` with the learned query `cls_q`.

        `cls_q` is expected to be `(1, C)`; it is broadcast to `(B, 1, C)`.
        Returns a tensor of shape `(B, C)`.
        """
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))                                    # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).permute(2, 0, 1, 3).unbind(0)   # each (B, N, C)

        # NOTE(review): logits are not scaled by 1/sqrt(C) as in standard
        # scaled dot-product attention; kept unscaled to match the original runs.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)                       # (B, 1, N)

        x = (attn @ v).reshape(B, C)                                           # (B, 1, C) -> (B, C)
        return self.proj(x)
 
    
class AvgAttnConcatPooling2d(nn.Module):
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.norm1 = norm(ni)
        self.norm2 = norm(ni)
        self.ffn = nn.Sequential(
            nn.Linear(ni, int(ni*ffn_expand)),
            act_cls(),
            norm(int(ni*ffn_expand)),
            nn.Linear(int(ni*ffn_expand), ni)
        )
        self.norm3 = norm(ni)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm4 = norm(ni)
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        a = self.cls_q + self.norm1(self.attn(x, self.cls_q))
        a = a + self.ffn(self.norm2(a))
        return torch.cat([self.norm4(self.pool(x).flatten(1)), self.norm3(a)], dim=1)

    @torch.no_grad()
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

In [None]:
def AvgAttnConcatPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    """Build the classifier head `[AvgAttnConcatPooling2d, nn.Linear(2*ni, n_out)]`.

    The linear layer takes `2*ni` inputs because the pooling layer concatenates
    its average-pooled and attention-branch features. The linear layer is
    re-initialized with the pooling layer's `_init_weights`.
    """
    pooling = AvgAttnConcatPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    classifier = nn.Linear(2*ni, n_out)
    with torch.no_grad():
        pooling._init_weights(classifier)
    return [pooling, classifier]

In [None]:
train(
    run_name='XResNeXt50S AvgAttnConcatPool Sandwich',
    model=partial(xeca_resnext50s, custom_head=AvgAttnConcatPoolHead),
    lr=2e-3,
)

epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.769979,3.230634,0.167111,0.247273,0.150634,00:53
1,3.58361,3.007699,0.221425,0.315229,0.21336,00:49
2,3.42674,2.959347,0.261177,0.324164,0.249145,00:48
3,3.34198,2.897947,0.277429,0.351575,0.26592,00:49
4,3.233409,2.725884,0.31686,0.404361,0.310648,00:49
5,3.107635,2.720646,0.305318,0.39742,0.299545,00:47
6,3.083304,2.622248,0.341751,0.432612,0.324462,00:49
7,3.013007,2.372486,0.427791,0.509718,0.428968,00:49
8,2.935696,2.385368,0.423399,0.501324,0.410507,00:48
9,2.890918,2.279433,0.462729,0.543739,0.465348,00:48


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.797383,3.329143,0.147197,0.225462,0.135632,00:50
1,3.557944,3.058341,0.198868,0.294455,0.194384,00:46
2,3.526908,3.143579,0.230255,0.282881,0.19874,00:48
3,3.613389,3.494285,0.126274,0.198042,0.105002,00:50
4,3.700536,3.364957,0.139675,0.197343,0.123993,00:50
5,3.836007,3.565285,0.08634,0.162198,0.075368,00:50
6,3.770156,3.41975,0.119092,0.196356,0.104418,00:49
7,3.6921,3.269138,0.130868,0.228642,0.118043,00:48
8,3.615143,3.396312,0.124129,0.208202,0.107933,00:49
9,3.541628,3.027871,0.211955,0.298568,0.202877,00:50


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.80249,3.240008,0.156091,0.237269,0.14236,00:50
1,3.673523,3.501035,0.136994,0.184378,0.114588,00:48
2,3.757553,3.469865,0.124236,0.190305,0.104061,00:48
3,3.698674,3.325392,0.166633,0.224903,0.129322,00:50
4,3.665573,3.334323,0.136503,0.220509,0.112172,00:50
5,3.637837,3.438549,0.100692,0.18554,0.086132,00:50
6,3.56082,3.041651,0.21965,0.298534,0.198681,00:50
7,3.454649,2.923695,0.253811,0.324081,0.246071,00:48
8,3.345901,2.915993,0.24175,0.332594,0.243845,00:50
9,3.311073,2.84842,0.30336,0.358164,0.273363,00:50


# AvgAttnConcatPooling2d Sandwich Act

In [None]:
class AttentionPool2d(nn.Module):
    """Attention for Learned Aggregation: pool a feature map with one learned query.

    The channels-first feature map `x` (e.g. `(B, C, H, W)`) is flattened into `N`
    tokens of width `C`, normalized, projected to keys/values, and attended over
    by a single query derived from `cls_q`. The result is a `(B, C)` pooled vector.

    Args:
        ni: Channel count `C`; also the width of the q/k/v/output projections.
        bias: Whether the query and key/value projections carry a bias.
        norm: Factory called as `norm(ni)` for the token normalization.
            `nn.BatchNorm2d` is applied to the 4-D map before flattening,
            `nn.BatchNorm1d` to the flattened `(B, C, N)` map, and any other
            norm (e.g. `nn.LayerNorm`) to the `(B, N, C)` token sequence.
    """
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)
        if isinstance(self.norm, nn.BatchNorm2d):
            # BatchNorm2d only accepts 4-D input, so it must run before flattening.
            self.norm_forward = self.bn2d_norm_flat
        elif isinstance(self.norm, nn.BatchNorm1d):
            self.norm_forward = self.bn_norm_flat
        else:
            self.norm_forward = self.norm_flat

    def bn2d_norm_flat(self, x:Tensor):
        "Normalize the 4-D map with BatchNorm2d, then flatten to `(B, N, C)` tokens."
        return self.norm(x).flatten(2).transpose(1,2)

    def bn_norm_flat(self, x:Tensor):
        "Flatten to `(B, C, N)`, normalize per channel (BatchNorm1d), then transpose to `(B, N, C)`."
        return self.norm(x.flatten(2)).transpose(1,2)

    def norm_flat(self, x:Tensor):
        "Flatten to `(B, N, C)` tokens, then normalize the last dim (e.g. LayerNorm)."
        return self.norm(x.flatten(2).transpose(1,2))

    def forward(self, x:Tensor, cls_q:Tensor):
        """Attend over the spatial tokens of `x` with the learned query `cls_q`.

        `cls_q` is expected to be `(1, C)`; it is broadcast to `(B, 1, C)`.
        Returns a tensor of shape `(B, C)`.
        """
        x = self.norm_forward(x)
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, -1, -1))                                    # (B, 1, C)
        k, v = self.vk(x).reshape(B, N, 2, C).permute(2, 0, 1, 3).unbind(0)   # each (B, N, C)

        # NOTE(review): logits are not scaled by 1/sqrt(C) as in standard
        # scaled dot-product attention; kept unscaled to match the original runs.
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)                       # (B, 1, N)

        x = (attn @ v).reshape(B, C)                                           # (B, 1, C) -> (B, C)
        return self.proj(x)
 
    
class AvgAttnConcatPooling2d(nn.Module):
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        self.cls_q = nn.Parameter(torch.zeros([1,ni]))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.norm1 = norm(ni)
        self.norm2 = norm(ni)
        self.ffn = nn.Sequential(
            nn.Linear(ni, int(ni*ffn_expand)),
            act_cls(),
            norm(int(ni*ffn_expand)),
            nn.Linear(int(ni*ffn_expand), ni)
        )
        self.norm3 = norm(ni)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm4 = norm(ni)
        self.act = act_cls()
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        a = self.cls_q + self.norm1(self.attn(x, self.cls_q))
        a = a + self.ffn(self.norm2(a))
        return self.act(torch.cat([self.norm4(self.pool(x).flatten(1)), self.norm3(a)], dim=1))

    @torch.no_grad()
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

In [None]:
def AvgAttnConcatPoolHead(ni, n_out, norm=nn.LayerNorm, ffn_expand=3, **kwargs):
    """Build the classifier head `[AvgAttnConcatPooling2d, nn.Linear(2*ni, n_out)]`.

    The linear layer takes `2*ni` inputs because the pooling layer concatenates
    its average-pooled and attention-branch features. The linear layer is
    re-initialized with the pooling layer's `_init_weights`.
    """
    pooling = AvgAttnConcatPooling2d(ni, norm=norm, ffn_expand=ffn_expand, **kwargs)
    classifier = nn.Linear(2*ni, n_out)
    with torch.no_grad():
        pooling._init_weights(classifier)
    return [pooling, classifier]

In [None]:
train(
    run_name='XResNeXt50S AvgAttnConcatPool Sandwich Act',
    model=partial(xeca_resnext50s, custom_head=AvgAttnConcatPoolHead),
    lr=4e-3,
)

epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.957834,3.904493,0.045493,0.089571,0.024212,00:43
1,3.811099,3.739213,0.05381,0.139383,0.042106,00:39
2,3.633473,3.219637,0.134388,0.226387,0.114541,00:39
3,3.50922,3.31318,0.144331,0.212916,0.118862,00:39
4,3.353218,2.818501,0.268146,0.364152,0.249509,00:39
5,3.168801,2.651315,0.30376,0.415418,0.304272,00:39
6,3.039989,2.499519,0.363672,0.459276,0.359966,00:39
7,2.890137,2.303134,0.435728,0.526297,0.423356,00:39
8,2.721788,2.37219,0.464662,0.496042,0.451075,00:39
9,2.631593,2.09113,0.521259,0.599322,0.524443,00:39


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.81437,3.292209,0.149292,0.234713,0.130626,00:39
1,3.621014,3.088675,0.18219,0.298327,0.177907,00:39
2,3.488069,3.038923,0.216397,0.298194,0.194282,00:40
3,3.353477,2.803133,0.269019,0.368109,0.252404,00:40
4,3.204453,2.618281,0.331061,0.425963,0.324085,00:40
5,3.003161,2.387838,0.40018,0.496896,0.393588,00:39
6,2.836666,2.282184,0.418135,0.532971,0.422933,00:40
7,2.682908,2.290159,0.426787,0.519691,0.438015,00:40
8,2.620272,2.084296,0.508304,0.601991,0.525089,00:40
9,2.492067,1.80744,0.654985,0.701835,0.653134,00:40


epoch,train_loss,valid_loss,balanced_accuracy_score,matthews_corrcoef,f1_score,time
0,3.803858,3.353991,0.113991,0.209365,0.109143,00:39
1,3.585282,3.137198,0.219072,0.281566,0.183913,00:39
2,3.390608,2.895007,0.23455,0.343156,0.222605,00:40
3,3.209651,2.742712,0.296609,0.389848,0.285701,00:40
4,3.085811,2.723303,0.323473,0.396492,0.299856,00:40
5,2.868383,2.397392,0.416065,0.491373,0.409788,00:40
6,2.774595,2.09107,0.522658,0.605469,0.536488,00:40
7,2.656533,1.96366,0.568011,0.646423,0.593274,00:40
8,2.501331,1.88503,0.612376,0.677192,0.617468,00:40
9,2.431914,1.779635,0.662062,0.714387,0.669555,00:39
