In [1]:
import numpy as np
import pandas as pd
import random
from glob import glob
import os, shutil
from tqdm import tqdm
tqdm.pandas()
import time
import copy
import joblib
from collections import defaultdict
import gc
from IPython import display as ipd
import math
# visualization
import cv2
from glob import glob
# Sklearn
from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, confusion_matrix, roc_curve
import timm
# PyTorch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from torch.optim.swa_utils import AveragedModel, SWALR
from transformers import get_cosine_schedule_with_warmup
from collections import defaultdict
# import matplotlib.pyplot as plt
# Albumentations for augmentations
import albumentations as A
import albumentations
import albumentations as albu
from albumentations.pytorch import ToTensorV2
from datetime import datetime
import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CFG:
    seed = 1
    model_name = "tf_efficientnetv2_b2"
    train_bs = 16
    valid_bs = 64
    image_size = 1024
    epochs = 25
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(CFG.device)

cuda:0


In [3]:
df = pd.read_csv("train_10folds.csv")
df.head()

Unnamed: 0,site_id,patient_id,image_id,laterality,view,age,cancer,biopsy,invasive,BIRADS,implant,density,machine_id,difficult_negative_case,split,fold
0,2,10006,462822612,L,CC,61.0,0,0,0,,0,,29,False,10006_L,0
1,2,10006,1459541791,L,MLO,61.0,0,0,0,,0,,29,False,10006_L,0
2,2,10006,1864590858,R,MLO,61.0,0,0,0,,0,,29,False,10006_R,2
3,2,10006,1874946579,R,CC,61.0,0,0,0,,0,,29,False,10006_R,2
4,2,10011,220375232,L,CC,55.0,0,0,0,0.0,0,,21,True,10011_L,5


In [4]:
is_hol = df['cancer'] == 1
df_try = df[is_hol]
df1 = df.append([df_try]*1,ignore_index=True)
print(len(df1))

55864


In [5]:
from functools import partial

import torch
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from torch import nn

NORM_EPS = 1e-5

def merge_pre_bn(module, pre_bn_1, pre_bn_2=None):
    """ Merge pre BN to reduce inference runtime.
    """
    weight = module.weight.data
    if module.bias is None:
        zeros = torch.zeros(module.out_channels, device=weight.device).type(weight.type())
        module.bias = nn.Parameter(zeros)
    bias = module.bias.data
    if pre_bn_2 is None:
        assert pre_bn_1.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
        assert pre_bn_1.affine is True, "Unsupport bn_module.affine is False"

        scale_invstd = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
        extra_weight = scale_invstd * pre_bn_1.weight
        extra_bias = pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd
    else:
        assert pre_bn_1.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
        assert pre_bn_1.affine is True, "Unsupport bn_module.affine is False"

        assert pre_bn_2.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
        assert pre_bn_2.affine is True, "Unsupport bn_module.affine is False"

        scale_invstd_1 = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
        scale_invstd_2 = pre_bn_2.running_var.add(pre_bn_2.eps).pow(-0.5)

        extra_weight = scale_invstd_1 * pre_bn_1.weight * scale_invstd_2 * pre_bn_2.weight
        extra_bias = scale_invstd_2 * pre_bn_2.weight *(pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd_1 - pre_bn_2.running_mean) + pre_bn_2.bias

    if isinstance(module, nn.Linear):
        extra_bias = weight @ extra_bias
        weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
    elif isinstance(module, nn.Conv2d):
        assert weight.shape[2] == 1 and weight.shape[3] == 1
        weight = weight.reshape(weight.shape[0], weight.shape[1])
        extra_bias = weight @ extra_bias
        weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
        weight = weight.reshape(weight.shape[0], weight.shape[1], 1, 1)
    bias.add_(extra_bias)

    module.weight.data = weight
    module.bias.data = bias

class ConvBNReLU(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            groups=1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                              padding=1, groups=groups, bias=False)
        self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class PatchEmbed(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1):
        super(PatchEmbed, self).__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        if stride == 2:
            self.avgpool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        elif in_channels != out_channels:
            self.avgpool = nn.Identity()
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        else:
            self.avgpool = nn.Identity()
            self.conv = nn.Identity()
            self.norm = nn.Identity()

    def forward(self, x):
        return self.norm(self.conv(self.avgpool(x)))


class MHCA(nn.Module):
    """
    Multi-Head Convolutional Attention
    """
    def __init__(self, out_channels, head_dim):
        super(MHCA, self).__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        self.group_conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
                                       padding=1, groups=out_channels // head_dim, bias=False)
        self.norm = norm_layer(out_channels)
        self.act = nn.ReLU(inplace=True)
        self.projection = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.group_conv3x3(x)
        out = self.norm(out)
        out = self.act(out)
        out = self.projection(out)
        return out


class Mlp(nn.Module):
    def __init__(self, in_features, out_features=None, mlp_ratio=None, drop=0., bias=True):
        super().__init__()
        out_features = out_features or in_features
        hidden_dim = _make_divisible(in_features * mlp_ratio, 32)
        self.conv1 = nn.Conv2d(in_features, hidden_dim, kernel_size=1, bias=bias)
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(hidden_dim, out_features, kernel_size=1, bias=bias)
        self.drop = nn.Dropout(drop)

    def merge_bn(self, pre_norm):
        merge_pre_bn(self.conv1, pre_norm)

    def forward(self, x):
        x = self.conv1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.conv2(x)
        x = self.drop(x)
        return x


class NCB(nn.Module):
    """
    Next Convolution Block
    """
    def __init__(self, in_channels, out_channels, stride=1, path_dropout=0,
                 drop=0, head_dim=32, mlp_ratio=3):
        super(NCB, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        assert out_channels % head_dim == 0

        self.patch_embed = PatchEmbed(in_channels, out_channels, stride)
        self.mhca = MHCA(out_channels, head_dim)
        self.attention_path_dropout = DropPath(path_dropout)

        self.norm = norm_layer(out_channels)
        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop, bias=True)
        self.mlp_path_dropout = DropPath(path_dropout)
        self.is_bn_merged = False

    def merge_bn(self):
        if not self.is_bn_merged:
            self.mlp.merge_bn(self.norm)
            self.is_bn_merged = True

    def forward(self, x):
        x = self.patch_embed(x)
        x = x + self.attention_path_dropout(self.mhca(x))
        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm(x)
        else:
            out = x
        x = x + self.mlp_path_dropout(self.mlp(out))
        return x


class E_MHSA(nn.Module):
    """
    Efficient Multi-Head Self Attention
    """
    def __init__(self, dim, out_dim=None, head_dim=32, qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.proj = nn.Linear(self.dim, self.out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio ** 2
        if sr_ratio > 1:
            self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
        self.is_bn_merged = False

    def merge_bn(self, pre_bn):
        merge_pre_bn(self.q, pre_bn)
        if self.sr_ratio > 1:
            merge_pre_bn(self.k, pre_bn, self.norm)
            merge_pre_bn(self.v, pre_bn, self.norm)
        else:
            merge_pre_bn(self.k, pre_bn)
            merge_pre_bn(self.v, pre_bn)
        self.is_bn_merged = True

    def forward(self, x):
        B, N, C = x.shape
        q = self.q(x)
        q = q.reshape(B, N, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.transpose(1, 2)
            x_ = self.sr(x_)
            if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
                x_ = self.norm(x_)
            x_ = x_.transpose(1, 2)
            k = self.k(x_)
            k = k.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
            v = self.v(x_)
            v = v.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
        else:
            k = self.k(x)
            k = k.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
            v = self.v(x)
            v = v.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
        attn = (q @ k) * self.scale

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class NTB(nn.Module):
    """
    Next Transformer Block
    """
    def __init__(
            self, in_channels, out_channels, path_dropout, stride=1, sr_ratio=1,
            mlp_ratio=2, head_dim=32, mix_block_ratio=0.75, attn_drop=0, drop=0,
    ):
        super(NTB, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mix_block_ratio = mix_block_ratio
        norm_func = partial(nn.BatchNorm2d, eps=NORM_EPS)

        self.mhsa_out_channels = _make_divisible(int(out_channels * mix_block_ratio), 32)
        self.mhca_out_channels = out_channels - self.mhsa_out_channels

        self.patch_embed = PatchEmbed(in_channels, self.mhsa_out_channels, stride)
        self.norm1 = norm_func(self.mhsa_out_channels)
        self.e_mhsa = E_MHSA(self.mhsa_out_channels, head_dim=head_dim, sr_ratio=sr_ratio,
                             attn_drop=attn_drop, proj_drop=drop)
        self.mhsa_path_dropout = DropPath(path_dropout * mix_block_ratio)

        self.projection = PatchEmbed(self.mhsa_out_channels, self.mhca_out_channels, stride=1)
        self.mhca = MHCA(self.mhca_out_channels, head_dim=head_dim)
        self.mhca_path_dropout = DropPath(path_dropout * (1 - mix_block_ratio))

        self.norm2 = norm_func(out_channels)
        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop)
        self.mlp_path_dropout = DropPath(path_dropout)

        self.is_bn_merged = False

    def merge_bn(self):
        if not self.is_bn_merged:
            self.e_mhsa.merge_bn(self.norm1)
            self.mlp.merge_bn(self.norm2)
            self.is_bn_merged = True

    def forward(self, x):
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm1(x)
        else:
            out = x
        out = rearrange(out, "b c h w -> b (h w) c")  # b n c
        out = self.mhsa_path_dropout(self.e_mhsa(out))
        x = x + rearrange(out, "b (h w) c -> b c h w", h=H)

        out = self.projection(x)
        out = out + self.mhca_path_dropout(self.mhca(out))
        x = torch.cat([x, out], dim=1)

        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm2(x)
        else:
            out = x
        x = x + self.mlp_path_dropout(self.mlp(out))
        return x


class NextViT(nn.Module):
    def __init__(self, stem_chs, depths, path_dropout, attn_drop=0, drop=0, num_classes=1000,
                 strides=[1, 2, 2, 2], sr_ratios=[8, 4, 2, 1], head_dim=32, mix_block_ratio=0.75,
                 use_checkpoint=False):
        super(NextViT, self).__init__()
        self.use_checkpoint = use_checkpoint

        self.stage_out_channels = [[96] * (depths[0]),
                                   [192] * (depths[1] - 1) + [256],
                                   [384, 384, 384, 384, 512] * (depths[2] // 5),
                                   [768] * (depths[3] - 1) + [1024]]

        # Next Hybrid Strategy
        self.stage_block_types = [[NCB] * depths[0],
                                  [NCB] * (depths[1] - 1) + [NTB],
                                  [NCB, NCB, NCB, NCB, NTB] * (depths[2] // 5),
                                  [NCB] * (depths[3] - 1) + [NTB]]

        self.stem = nn.Sequential(
            ConvBNReLU(3, stem_chs[0], kernel_size=3, stride=2),
            ConvBNReLU(stem_chs[0], stem_chs[1], kernel_size=3, stride=1),
            ConvBNReLU(stem_chs[1], stem_chs[2], kernel_size=3, stride=1),
            ConvBNReLU(stem_chs[2], stem_chs[2], kernel_size=3, stride=2),
        )
        input_channel = stem_chs[-1]
        features = []
        idx = 0
        dpr = [x.item() for x in torch.linspace(0, path_dropout, sum(depths))]  # stochastic depth decay rule
        for stage_id in range(len(depths)):
            numrepeat = depths[stage_id]
            output_channels = self.stage_out_channels[stage_id]
            block_types = self.stage_block_types[stage_id]
            for block_id in range(numrepeat):
                if strides[stage_id] == 2 and block_id == 0:
                    stride = 2
                else:
                    stride = 1
                output_channel = output_channels[block_id]
                block_type = block_types[block_id]
                if block_type is NCB:
                    layer = NCB(input_channel, output_channel, stride=stride, path_dropout=dpr[idx + block_id],
                                drop=drop, head_dim=head_dim)
                    features.append(layer)
                elif block_type is NTB:
                    layer = NTB(input_channel, output_channel, path_dropout=dpr[idx + block_id], stride=stride,
                                sr_ratio=sr_ratios[stage_id], head_dim=head_dim, mix_block_ratio=mix_block_ratio,
                                attn_drop=attn_drop, drop=drop)
                    features.append(layer)
                input_channel = output_channel
            idx += numrepeat
        self.features = nn.Sequential(*features)

        self.norm = nn.BatchNorm2d(output_channel, eps=NORM_EPS)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.proj_head = nn.Sequential(
            nn.Linear(output_channel, num_classes),
        )

        self.stage_out_idx = [sum(depths[:idx + 1]) - 1 for idx in range(len(depths))]
        print('initialize_weights...')
        self._initialize_weights()

    def merge_bn(self):
        self.eval()
        for idx, module in self.named_modules():
            if isinstance(module, NCB) or isinstance(module, NTB):
                module.merge_bn()

    def _initialize_weights(self):
        for n, m in self.named_modules():
            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                trunc_normal_(m.weight, std=.02)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.stem(x)
        for idx, layer in enumerate(self.features):
            if self.use_checkpoint:
                x = checkpoint.checkpoint(layer, x)
            else:
                x = layer(x)
        x = self.norm(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.proj_head(x)
        return x


@register_model
def nextvit_small(pretrained=False, pretrained_cfg=None, **kwargs):
    model = NextViT(stem_chs=[64, 32, 64], depths=[3, 4, 10, 3], path_dropout=0.1, **kwargs)
    return model


@register_model
def nextvit_base(pretrained=False, pretrained_cfg=None, **kwargs):
    model = NextViT(stem_chs=[64, 32, 64], depths=[3, 4, 20, 3], path_dropout=0.2, **kwargs)
    return model


@register_model
def nextvit_large(pretrained=False, pretrained_cfg=None, **kwargs):
    model = NextViT(stem_chs=[64, 32, 64], depths=[3, 4, 30, 3], path_dropout=0.2, **kwargs)
    return model



In [6]:
def init_logger(log_file='train1.log'):
    from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = init_logger()
now = datetime.now()
datetime_now = now.strftime("%m/%d/%Y, %H:%M:%S")
LOGGER.info(f"Date :{datetime_now}")

Date :01/30/2023, 14:08:30


In [7]:
from albumentations import DualTransform
image_size = 1024
def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
    h, w = img.shape[:2]
    if max(w, h) == size:
        return img
    if w > h:
        scale = size / w
        h = h * scale
        w = size
    else:
        scale = size / h
        w = w * scale
        h = size
    interpolation = interpolation_up if scale > 1 else interpolation_down
    resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
    return resized


class IsotropicResize(DualTransform):
    def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC,
                 always_apply=False, p=1):
        super(IsotropicResize, self).__init__(always_apply, p)
        self.max_side = max_side
        self.interpolation_down = interpolation_down
        self.interpolation_up = interpolation_up

    def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params):
        return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down,
                                          interpolation_up=interpolation_up)

    def apply_to_mask(self, img, **params):
        return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params)

    def get_transform_init_args_names(self):
        return ("max_side", "interpolation_down", "interpolation_up")
    
data_transforms = {
    "train": A.Compose([
        # A.Resize(image_size, image_size),
        # IsotropicResize(max_side = image_size),
        # A.PadIfNeeded(min_height=image_size, min_width=image_size, border_mode=cv2.BORDER_CONSTANT),
        # A.RandomBrightnessContrast(),
        # A.VerticalFlip(p=0.5),   
        # A.ColorJitter(),
        # A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
        # A.HorizontalFlip(p=0.5),
        # A.Cutout(max_h_size=int(image_size * 0.1), max_w_size=int(image_size * 0.1), num_holes=5, p=0.5),
        A.VerticalFlip(p=0.5),   
        A.ColorJitter(),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.Cutout(max_h_size=102, max_w_size=102, num_holes=5, p=0.5),
        # A.CLAHE(p=1.0),
        # albumentations.HorizontalFlip(p=0.5),
        # # albumentations.VerticalFlip(p=0.5),
        # albumentations.RandomBrightness(limit=0.2, p=0.75),
        # albumentations.RandomContrast(limit=0.2, p=0.75),

        # albumentations.OneOf([
        #     albumentations.OpticalDistortion(distort_limit=1.),
        #     albumentations.GridDistortion(num_steps=5, distort_limit=1.),
        # ], p=0.75),

        # albumentations.HueSaturationValue(hue_shift_limit=40, sat_shift_limit=40, val_shift_limit=0, p=0.75),
        # albumentations.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.3, rotate_limit=30, border_mode=0, p=0.75),
        # A.Cutout(always_apply=False, p=0.5, num_holes=1, max_h_size=409, max_w_size=409),
        # A.OneOf([ 
        # A.OpticalDistortion(distort_limit=1.0), 
        # A.GridDistortion(num_steps=5, distort_limit=1.),
        # A.ElasticTransform(alpha=3), ], p=0.2),
        # A.OneOf([
        #     # A.GaussNoise(var_limit=[10, 50]),
        #     A.GaussianBlur(),
        #     A.MotionBlur(),
        #     A.MedianBlur(), ], p=0.2),
        # A.OneOf([
        #     A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
        #     A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=1.0),
        #     A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
        # ], p=0.25),
        # A.CoarseDropout(max_holes=8, max_height=image_size//20, max_width=image_size//20,
        #                  min_holes=5, fill_value=0, mask_fill_value=0, p=0.5),
        # A.Normalize(mean=0, std=1),
        ToTensorV2(),], p=1.0),
    
    "valid": A.Compose([
        # IsotropicResize(max_side =image_size),
        # A.PadIfNeeded(min_height=image_size, min_width=image_size, border_mode=cv2.BORDER_CONSTANT),
        # A.Normalize(mean=0, std=1),
        # A.Resize(image_size, image_size),
        ToTensorV2(),
        ], p=1.0)
}

LOGGER.info(f"train transform{data_transforms['train']}")


train transformCompose([
  VerticalFlip(always_apply=False, p=0.5),
  ColorJitter(always_apply=False, p=0.5, brightness=[0.8, 1.2], contrast=[0.8, 1.2], saturation=[0.8, 1.2], hue=[-0.2, 0.2]),
  ShiftScaleRotate(always_apply=False, p=0.5, shift_limit_x=(-0.0625, 0.0625), shift_limit_y=(-0.0625, 0.0625), scale_limit=(-0.050000000000000044, 0.050000000000000044), rotate_limit=(-10, 10), interpolation=1, border_mode=4, value=None, mask_value=None, rotate_method='largest_box'),
  HorizontalFlip(always_apply=False, p=0.5),
  Cutout(always_apply=False, p=0.5, num_holes=5, max_h_size=102, max_w_size=102),
  ToTensorV2(always_apply=True, p=1.0, transpose_mask=False),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={})


In [8]:
# from albumentations import DualTransform
# image_size = 1024
# def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
#     h, w = img.shape[:2]
#     if max(w, h) == size:
#         return img
#     if w > h:
#         scale = size / w
#         h = h * scale
#         w = size
#     else:
#         scale = size / h
#         w = w * scale
#         h = size
#     interpolation = interpolation_up if scale > 1 else interpolation_down
#     resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
#     return resized


# class IsotropicResize(DualTransform):
#     def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC,
#                  always_apply=False, p=1):
#         super(IsotropicResize, self).__init__(always_apply, p)
#         self.max_side = max_side
#         self.interpolation_down = interpolation_down
#         self.interpolation_up = interpolation_up

#     def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params):
#         return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down,
#                                           interpolation_up=interpolation_up)

#     def apply_to_mask(self, img, **params):
#         return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params)

#     def get_transform_init_args_names(self):
#         return ("max_side", "interpolation_down", "interpolation_up")
    
# data_transforms = {
#     "train": A.Compose([
# #         A.Resize(image_size, image_size),
#         # IsotropicResize(max_side = image_size),
#        A.PadIfNeeded(min_width=image_size, border_mode=cv2.BORDER_CONSTANT),
#         albumentations.HorizontalFlip(p=0.5),
#         # albumentations.VerticalFlip(p=0.5),
#         albumentations.RandomBrightness(limit=0.2, p=0.75),
#         albumentations.RandomContrast(limit=0.2, p=0.75),

#         albumentations.OneOf([
#             albumentations.OpticalDistortion(distort_limit=1.),
#             albumentations.GridDistortion(num_steps=5, distort_limit=1.),
#         ], p=0.75),

#         albumentations.HueSaturationValue(hue_shift_limit=40, sat_shift_limit=40, val_shift_limit=0, p=0.75),
#         albumentations.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.3, rotate_limit=30, border_mode=0, p=0.75),
#         A.Cutout(always_apply=False, p=0.5, num_holes=1, max_h_size=409, max_w_size=409),
#         # A.RandomBrightnessContrast(),
#         # A.VerticalFlip(p=0.5),   
#         A.ColorJitter(p = 0.7),
#         # A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
#         # A.HorizontalFlip(p=0.5),
#         # A.Cutout(max_h_size=int(image_size * 0.1), max_w_size=int(image_size * 0.1), num_holes=5, p=0.5),
#         # albumentations.RandomBrightness(limit=0.2, p=0.75),
#         # albumentations.RandomContrast(limit=0.2, p=0.75),

#         # albumentations.OneOf([
#         #     albumentations.OpticalDistortion(distort_limit=1.),
#         #     albumentations.GridDistortion(num_steps=5, distort_limit=1.),
#         # ], p=0.75),

#         # albumentations.HueSaturationValue(hue_shift_limit=40, sat_shift_limit=40, val_shift_limit=0, p=0.75),
#         # albumentations.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.3, rotate_limit=30, border_mode=0, p=0.75),
#         # A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10, p=0.7),
#         # A.RandomBrightnessContrast(brightness_limit=(-0.2,0.2), contrast_limit=(-0.2, 0.2), p=0.7),
#         # A.CLAHE(p=0.5),
#         # albumentations.OneOf([
#         # albumentations.OpticalDistortion(distort_limit=1.),
#         # albumentations.GridDistortion(num_steps=5, distort_limit=1.),
#         # ], p=0.75),
#         # A.OneOf([
#         # A.GaussianBlur(),
#         # A.MotionBlur(),
#         # A.MedianBlur(), ], p=0.5),
#         # A.IAASharpen(p = 0.2),
#         # A.JpegCompression(p=0.2),
#         # A.Downscale(scale_min=0.5, scale_max=0.75),
#         # A.OneOf([ A.JpegCompression(), A.Downscale(scale_min=0.1, scale_max=0.15), ], p=0.2), 
#         # A.IAAPiecewiseAffine(),
# #         A.OneOf([ 
# #         A.OpticalDistortion(distort_limit=1.0), 
# #         A.GridDistortion(num_steps=5, distort_limit=1.),
# #         A.ElasticTransform(alpha=3), ], p=0.2),
# #         A.OneOf([
# #             A.GaussNoise(var_limit=[10, 50]),
# #             A.GaussianBlur(),
# #             A.MotionBlur(),
# #             A.MedianBlur(), ], p=0.2),
#         # A.OneOf([
#         #     A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
#         #     A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=1.0),
#         #     A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
#         # ], p=0.25),
#         # A.CoarseDropout(max_holes=8, max_height=image_size//20, max_width=image_size//20,
#         #                  min_holes=5, fill_value=0, mask_fill_value=0, p=0.5),
#         # A.Normalize(mean=0, std=1),
#         ToTensorV2(),], p=1.0),
    
#     "valid": A.Compose([
#         # IsotropicResize(max_side = image_size),
#         A.PadIfNeeded(min_height=image_size, min_width=image_size, border_mode=cv2.BORDER_CONSTANT),
#         # A.Normalize(mean=0, std=1),
# #         A.Resize(image_size, image_size),
#         ToTensorV2(),
#         ], p=1.0)
# }

# LOGGER.info(f"train transform{data_transforms['train']}")


In [9]:
def pad(array, target_shape):
    return np.pad(
        array,
        [(0, target_shape[i] - array.shape[i]) for i in range(len(array.shape))],
        "constant",
    )
    
def load_img(img_path):
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # image = pad(image, (1024, 800, 3))
        # img = img.reshape((*resize))
    return image
#     image = cv2.resize(image, (320, 320), cv2.INTER_NEAREST)
#     image = image.astype(np.float32)
#     mx = np.max(image)
#     if mx:
#         image/=mx
#     image = image /255.0
    
    return image
class BreastDataset(Dataset):
    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms
        
    def __getitem__(self, index):
        row = self.df.iloc[index]
        img_path = f"flip/{row.patient_id}_{row.image_id}.png"
        img = load_img(img_path)
        label = row['cancer']
        # img = np.transpose(img, (2, 0, 1))
        data = self.transforms(image=img)
        img  = data['image']
        # img = img/255.0
        return torch.tensor(img).float(), torch.tensor(label).long()
        
    def __len__(self):
        return len(self.df)
    
fold0 = df[df['fold']==0]
train_dataset = BreastDataset(fold0, transforms = data_transforms['train'])
image, label = train_dataset[0]
print(image.shape, label)
print(image.max())

torch.Size([3, 1344, 840]) tensor(0)
tensor(253.)


In [10]:

# from pylab import rcParams

# f, axarr = plt.subplots(1,15, figsize = (20, 20))
# imgs = []
# for p in range(15):
#     img, label = train_dataset[p]
#     img = img.transpose(0, 1).transpose(1,2).cpu().numpy()
#     img = img.astype(np.uint8)
#     imgs.append(img)
#     axarr[p].imshow(img)


# Model

In [11]:
class Model(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        # ,drop_rate = 0.3, drop_path_rate = 0.2
        self.backbone = timm.create_model(model_name, pretrained=True,drop_rate = 0.3, drop_path_rate = 0.2)
        self.fc = nn.Linear(self.backbone.classifier.in_features,2)
        self.backbone.classifier = nn.Identity()
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        x = self.backbone(x)
        x = self.fc(self.dropout(x))
        return x

class ModelNextVit(nn.Module):
    def __init__(self):
        super().__init__()
        # ,drop_rate = 0.3, drop_path_rate = 0.2
        self.checkpoint = torch.load('nextvit_small_in1k_384.pth')
        self.backbone = nextvit_small()
        self.backbone.load_state_dict(self.checkpoint['model'])
        self.backbone.proj_head = nn.Linear(1024, 2)

    def forward(self, x):
        x = self.backbone(x)
        return x

In [12]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer, required
import math

class AdamP(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
        super(AdamP, self).__init__(params, defaults)

    def _channel_view(self, x):
        return x.view(x.size(0), -1)

    def _layer_view(self, x):
        return x.view(1, -1)

    def _cosine_similarity(self, x, y, eps, view_func):
        x = view_func(x)
        y = view_func(y)

        return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()

    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
        wd = 1
        expand_size = [-1] + [1] * (len(p.shape) - 1)
        for view_func in [self._channel_view, self._layer_view]:

            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)

            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
                p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
                wd = wd_ratio

                return perturb, wd

        return perturb, wd

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                beta1, beta2 = group['betas']
                nesterov = group['nesterov']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                # Adam
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                step_size = group['lr'] / bias_correction1

                if nesterov:
                    perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
                else:
                    perturb = exp_avg / denom

                # Projection
                wd_ratio = 1
                if len(p.shape) > 1:
                    perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])

                # Weight decay
                if group['weight_decay'] > 0:
                    p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio)

                # Step
                p.data.add_(perturb, alpha=-step_size)

        return loss

class SGDP(Optimizer):
    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
                        nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
        super(SGDP, self).__init__(params, defaults)

    def _channel_view(self, x):
        return x.view(x.size(0), -1)

    def _layer_view(self, x):
        return x.view(1, -1)

    def _cosine_similarity(self, x, y, eps, view_func):
        x = view_func(x)
        y = view_func(y)

        return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()

    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
        wd = 1
        expand_size = [-1] + [1] * (len(p.shape) - 1)
        for view_func in [self._channel_view, self._layer_view]:

            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)

            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
                p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
                wd = wd_ratio

                return perturb, wd

        return perturb, wd

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['momentum'] = torch.zeros_like(p.data)

                # SGD
                buf = state['momentum']
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)
                if nesterov:
                    d_p = grad + momentum * buf
                else:
                    d_p = buf

                # Projection
                wd_ratio = 1
                if len(p.shape) > 1:
                    d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])

                # Weight decay
                if group['weight_decay'] > 0:
                    p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))

                # Step
                p.data.add_(d_p, alpha=-group['lr'])

        return loss

class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    def step(self, closure=None):
        raise NotImplementedError("SAM doesn't work like the other optimizers, you should first call `first_step` and the `second_step`; see the documentation for more info.")

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        p.grad.norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

In [14]:
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater thant or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)

class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super(GradualWarmupSchedulerV2, self).__init__(optimizer, multiplier, total_epoch, after_scheduler)
    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

class Lookahead(optim.Optimizer):
    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group):
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if 'slow_buffer' not in param_state:
                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
                param_state['slow_buffer'].copy_(fast_p.data)
            slow = param_state['slow_buffer']
            slow.add_(group['lookahead_alpha'], fast_p.data - slow)
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure=None):
        #assert id(self.param_groups) == id(self.base_optimizer.param_groups)
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['lookahead_step'] += 1
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self):
        fast_state_dict = self.base_optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict['state']
        param_groups = fast_state_dict['param_groups']
        return {
            'state': fast_state,
            'slow_state': slow_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        fast_state_dict = {
            'state': state_dict['state'],
            'param_groups': state_dict['param_groups'],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        # We want to restore the slow state, but share param_groups reference
        # with base_optimizer. This is a bit redundant but least code
        slow_state_new = False
        if 'slow_state' not in state_dict:
            print('Loading state_dict from optimizer without Lookahead applied.')
            state_dict['slow_state'] = defaultdict(dict)
            slow_state_new = True
        slow_state_dict = {
            'state': state_dict['slow_state'],
            'param_groups': state_dict['param_groups'],  # this is pointless but saves code
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
        if slow_state_new:
            # reapply defaults to catch missing lookahead specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)

In [15]:
def log_t(u, t):
    """Compute log_t for `u'."""
    if t==1.0:
        return u.log()
    else:
        return (u.pow(1.0 - t) - 1.0) / (1.0 - t)

def exp_t(u, t):
    """Compute exp_t for `u'."""
    if t==1:
        return u.exp()
    else:
        return (1.0 + (1.0-t)*u).relu().pow(1.0 / (1.0 - t))

def compute_normalization_fixed_point(activations, t, num_iters):

    """Returns the normalization value for each example (t > 1.0).
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same shape as activation with the last dimension being 1.
    """
    mu, _ = torch.max(activations, -1, keepdim=True)
    normalized_activations_step_0 = activations - mu

    normalized_activations = normalized_activations_step_0

    for _ in range(num_iters):
        logt_partition = torch.sum(
                exp_t(normalized_activations, t), -1, keepdim=True)
        normalized_activations = normalized_activations_step_0 * \
                logt_partition.pow(1.0-t)

    logt_partition = torch.sum(
            exp_t(normalized_activations, t), -1, keepdim=True)
    normalization_constants = - log_t(1.0 / logt_partition, t) + mu

    return normalization_constants

def compute_normalization_binary_search(activations, t, num_iters):

    """Returns the normalization value for each example (t < 1.0).
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (< 1.0 for finite support).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """

    mu, _ = torch.max(activations, -1, keepdim=True)
    normalized_activations = activations - mu

    effective_dim = \
        torch.sum(
                (normalized_activations > -1.0 / (1.0-t)).to(torch.int32),
            dim=-1, keepdim=True).to(activations.dtype)

    shape_partition = activations.shape[:-1] + (1,)
    lower = torch.zeros(shape_partition, dtype=activations.dtype, device=activations.device)
    upper = -log_t(1.0/effective_dim, t) * torch.ones_like(lower)

    for _ in range(num_iters):
        logt_partition = (upper + lower)/2.0
        sum_probs = torch.sum(
                exp_t(normalized_activations - logt_partition, t),
                dim=-1, keepdim=True)
        update = (sum_probs < 1.0).to(activations.dtype)
        lower = torch.reshape(
                lower * update + (1.0-update) * logt_partition,
                shape_partition)
        upper = torch.reshape(
                upper * (1.0 - update) + update * logt_partition,
                shape_partition)

    logt_partition = (upper + lower)/2.0
    return logt_partition + mu

class ComputeNormalization(torch.autograd.Function):
    """
    Class implementing custom backward pass for compute_normalization. See compute_normalization.
    """
    @staticmethod
    def forward(ctx, activations, t, num_iters):
        if t < 1.0:
            normalization_constants = compute_normalization_binary_search(activations, t, num_iters)
        else:
            normalization_constants = compute_normalization_fixed_point(activations, t, num_iters)

        ctx.save_for_backward(activations, normalization_constants)
        ctx.t=t
        return normalization_constants

    @staticmethod
    def backward(ctx, grad_output):
        activations, normalization_constants = ctx.saved_tensors
        t = ctx.t
        normalized_activations = activations - normalization_constants 
        probabilities = exp_t(normalized_activations, t)
        escorts = probabilities.pow(t)
        escorts = escorts / escorts.sum(dim=-1, keepdim=True)
        grad_input = escorts * grad_output
        
        return grad_input, None, None

def compute_normalization(activations, t, num_iters=5):
    """Returns the normalization value for each example. 
    Backward pass is implemented.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """
    return ComputeNormalization.apply(activations, t, num_iters)

def tempered_sigmoid(activations, t, num_iters = 5):
    """Tempered sigmoid function.
    Args:
      activations: Activations for the positive class for binary classification.
      t: Temperature tensor > 0.0.
      num_iters: Number of iterations to run the method.
    Returns:
      A probabilities tensor.
    """
    internal_activations = torch.stack([activations,
        torch.zeros_like(activations)],
        dim=-1)
    internal_probabilities = tempered_softmax(internal_activations, t, num_iters)
    return internal_probabilities[..., 0]


def tempered_softmax(activations, t, num_iters=5):
    """Tempered softmax function.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature > 1.0.
      num_iters: Number of iterations to run the method.
    Returns:
      A probabilities tensor.
    """
    if t == 1.0:
        return activations.softmax(dim=-1)

    normalization_constants = compute_normalization(activations, t, num_iters)
    return exp_t(activations - normalization_constants, t)

def bi_tempered_binary_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing = 0.0,
        num_iters=5,
        reduction='mean'):

    """Bi-Tempered binary logistic loss.
    Args:
      activations: A tensor containing activations for class 1.
      labels: A tensor with shape as activations, containing probabilities for class 1
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing
      num_iters: Number of iterations to run the method.
    Returns:
      A loss tensor.
    """
    internal_activations = torch.stack([activations,
        torch.zeros_like(activations)],
        dim=-1)
    internal_labels = torch.stack([labels.to(activations.dtype),
        1.0 - labels.to(activations.dtype)],
        dim=-1)
    return bi_tempered_logistic_loss(internal_activations, 
            internal_labels,
            t1,
            t2,
            label_smoothing = label_smoothing,
            num_iters = num_iters,
            reduction = reduction)

def bi_tempered_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing=0.0,
        num_iters=5,
        reduction = 'mean'):

    """Bi-Tempered Logistic Loss.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      labels: A tensor with shape and dtype as activations (onehot), 
        or a long tensor of one dimension less than activations (pytorch standard)
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1). Default 0.0.
      num_iters: Number of iterations to run the method. Default 5.
      reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Default ``'mean'``.
        ``'none'``: No reduction is applied, return shape is shape of
        activations without the last dimension.
        ``'mean'``: Loss is averaged over minibatch. Return shape (1,)
        ``'sum'``: Loss is summed over minibatch. Return shape (1,)
    Returns:
      A loss tensor.
    """

    if len(labels.shape)<len(activations.shape): #not one-hot
        labels_onehot = torch.zeros_like(activations)
        labels_onehot.scatter_(1, labels[..., None], 1)
    else:
        labels_onehot = labels

    if label_smoothing > 0:
        num_classes = labels_onehot.shape[-1]
        labels_onehot = ( 1 - label_smoothing * num_classes / (num_classes - 1) ) \
                * labels_onehot + \
                label_smoothing / (num_classes - 1)

    probabilities = tempered_softmax(activations, t2, num_iters)

    loss_values = labels_onehot * log_t(labels_onehot + 1e-10, t1) \
            - labels_onehot * log_t(probabilities, t1) \
            - labels_onehot.pow(2.0 - t1) / (2.0 - t1) \
            + probabilities.pow(2.0 - t1) / (2.0 - t1)
    loss_values = loss_values.sum(dim = -1) #sum over classes

    if reduction == 'none':
        return loss_values
    if reduction == 'sum':
        return loss_values.sum()
    if reduction == 'mean':
        return loss_values.mean()

class BiTemperedLogisticLoss(nn.Module): 
    def __init__(self, t1, t2, smoothing=0.0): 
        super(BiTemperedLogisticLoss, self).__init__() 
        self.t1 = t1
        self.t2 = t2
        self.smoothing = smoothing
    def forward(self, logit_label, truth_label):
        loss_label = bi_tempered_logistic_loss(
            logit_label, truth_label,
            t1=self.t1, t2=self.t2,
            label_smoothing=self.smoothing,
            reduction='none'
        )
        
        loss_label = loss_label.mean()
        return loss_label

In [16]:

def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    torch.cuda.empty_cache()
    gc.collect()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    truth = []
    pred = []
    global_step = 0
    scaler = GradScaler()
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc='Train')
    for step, (images, labels) in pbar:
        optimizer.zero_grad()
        data_time.update(time.time() - end)
        images = images.to(device)
        
        
        labels = labels.to(device)
        batch_size = labels.size(0)
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)
            # loss.backward()
            # optimizer.first_step(zero_grad=True)
            # criterion(model(images), labels).backward()
            # optimizer.second_step(zero_grad=True)
            # record loss
        losses.update(loss.item(), batch_size)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # global_step += 1
        scheduler.step()
            # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
#         if step % 100 == 0 or step == (len(train_loader)-1):
#             print('Epoch: [{0}][{1}/{2}] '
#                       'Data {data_time.val:.6f} ({data_time.avg:.6f}) '
#                       'Elapsed {remain:s} '
#                       'Loss: {loss.val:.6f}({loss.avg:.6f}) '
#                       'LR: {lr:.6f}  '
#                       .format(
#                        epoch+1, step, len(train_loader), batch_time=batch_time,
#                        data_time=data_time, loss=losses,
#                        remain=timeSince(start, float(step+1)/len(train_loader)),
#                        lr=scheduler.get_lr()[0],
#                        ))
        torch.cuda.empty_cache()
        gc.collect()
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(train_loss=f'{losses.avg:0.4f}',
                        lr=f'{current_lr:0.8f}',
                        gpu_mem=f'{mem:0.2f} GB')

    return losses.avg

def valid_fn_no_sigmoid(val_dataloader, model, criterion, device):
    losses = AverageMeter()
    model.eval()
    truth = []
    preds = []
    valid_labels = []
    start = end = time.time()
    pbar = tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc='Val')
    for step, (images, labels) in pbar:
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        with torch.no_grad():
            outputs = model(images)
        valid_labels.append(labels.cpu().numpy())
        loss = criterion(outputs, labels)
#         loss = bi_tempered_logistic_loss(outputs, labels, t1=0.8, t2 = 1.4)
        losses.update(loss.item(), batch_size)
#         print(outputs)
        preds.append((outputs).to('cpu').numpy())
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        pbar.set_postfix(eval_loss=f'{losses.avg:0.4f}',
                        gpu_mem=f'{mem:0.2f} GB')
    predictions = np.concatenate(preds)
    valid_labels = np.concatenate(valid_labels)
    return losses.avg, predictions, valid_labels


def valid_fn(val_dataloader, model, criterion, device):
    losses = AverageMeter()
    model.eval()
    truth = []
    preds = []
    valid_labels = []
    start = end = time.time()
    pbar = tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc='Val')
    for step, (images, labels) in pbar:
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        with torch.no_grad():
            outputs = model(images)
        valid_labels.append(labels.cpu().numpy())
        loss = criterion(outputs, labels)
#         loss = bi_tempered_logistic_loss(outputs, labels, t1=0.8, t2 = 1.4)
        losses.update(loss.item(), batch_size)
#         print(outputs)
        preds.append(torch.sigmoid(outputs).to('cpu').numpy())
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        pbar.set_postfix(eval_loss=f'{losses.avg:0.4f}',
                        gpu_mem=f'{mem:0.2f} GB')
    predictions = np.concatenate(preds)
    valid_labels = np.concatenate(valid_labels)
    return losses.avg, predictions, valid_labels
def valid_fn_two(val_dataloader, model, criterion, device):
    losses = AverageMeter()
    model.eval()
    truth = []
    preds = []
    valid_labels = []
    start = end = time.time()
    pbar = tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc='Val')
    for step, (images, labels) in pbar:
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        with torch.no_grad():
            outputs = model(images)
        valid_labels.append(labels.cpu().numpy())
        loss = criterion(outputs, labels)
#         loss = bi_tempered_logistic_loss(outputs, labels, t1=0.8, t2 = 1.4)
        losses.update(loss.item(), batch_size)
#         print(outputs)
        preds.append(F.softmax(outputs).to('cpu').numpy())
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        pbar.set_postfix(eval_loss=f'{losses.avg:0.4f}',
                        gpu_mem=f'{mem:0.2f} GB')
    predictions = np.concatenate(preds)
    valid_labels = np.concatenate(valid_labels)
    return losses.avg, predictions, valid_labels

def valid_fn_two_flip(val_dataloader, model, criterion, device):
    losses = AverageMeter()
    model.eval()
    truth = []
    preds = []
    valid_labels = []
    start = end = time.time()
    pbar = tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc='Val')
    for step, (images, labels) in pbar:
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        images = torch.flip(images, [3])
        with torch.no_grad():
            outputs = model(images)
        valid_labels.append(labels.cpu().numpy())
        loss = criterion(outputs, labels)
#         loss = bi_tempered_logistic_loss(outputs, labels, t1=0.8, t2 = 1.4)
        losses.update(loss.item(), batch_size)
#         print(outputs)
        preds.append(F.softmax(outputs).to('cpu').numpy())
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        pbar.set_postfix(eval_loss=f'{losses.avg:0.4f}',
                        gpu_mem=f'{mem:0.2f} GB')
    predictions = np.concatenate(preds)
    valid_labels = np.concatenate(valid_labels)
    return losses.avg, predictions, valid_labels

def valid_fn_flip(val_dataloader, model, criterion, device):
    losses = AverageMeter()
    model.eval()
    truth = []
    preds = []
    valid_labels = []
    start = end = time.time()
    pbar = tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc='Val')
    for step, (images, labels) in pbar:
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        images = torch.flip(images, [3])
        with torch.no_grad():
            outputs = model(images)
        valid_labels.append(labels.cpu().numpy())
        loss = criterion(outputs, labels)
#         loss = bi_tempered_logistic_loss(outputs, labels, t1=0.8, t2 = 1.4)
        losses.update(loss.item(), batch_size)
#         print(outputs)
        preds.append(torch.sigmoid(outputs).to('cpu').numpy())
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        pbar.set_postfix(eval_loss=f'{losses.avg:0.4f}',
                        gpu_mem=f'{mem:0.2f} GB')
    predictions = np.concatenate(preds)
    valid_labels = np.concatenate(valid_labels)
    return losses.avg, predictions, valid_labels

In [17]:
from exhaustive_weighted_random_sampler import ExhaustiveWeightedRandomSampler
def pfbeta(labels, predictions, beta=1):
    y_true_count = 0
    ctp = 0
    cfp = 0

    for idx in range(len(labels)):
        prediction = min(max(predictions[idx], 0), 1)
        if (labels[idx]):
            y_true_count += 1
            ctp += prediction
        else:
            cfp += prediction

    beta_squared = beta * beta
    c_precision = ctp / (ctp + cfp)
    c_recall = ctp / y_true_count
    if (c_precision > 0 and c_recall > 0):
        result = (1 + beta_squared) * (c_precision * c_recall) / (beta_squared * c_precision + c_recall)
        return result
    else:
        return 0
    
def dfs_freeze(module):
    for param in module.parameters():
        param.requires_grad = False
        
def dfs_unfreeze(module):
    for param in module.parameters():
        param.requires_grad = True
    
def set_seed(seed = 42):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.'''
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')

def sigmoid(x):
  return 1 / (1 + math.exp(-x))

set_seed(1)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gc.collect()
torch.cuda.empty_cache()
for fold in [0, 1, 2, 3, 4]:
    LOGGER.info(f"Fold: {fold}")
    LOGGER.info(f"Model name: {CFG.model_name}")
    # LOGGER.info(f"Model name: nextvit-small")
    # model = ModelNextVit().to(CFG.device)
    model = Model(model_name=CFG.model_name).to(device)
    # model = ModelVIT().to(CFG.device)
    train_df = df1[df1['fold']!=fold].reset_index(drop=True)
    valid_df = df[df['fold']==fold].reset_index(drop=True)
    # print(len(valid_df))
    LOGGER.info(f"Len train df: {len(train_df)}")
    LOGGER.info(f"Len valid df: {len(valid_df)}")
    # cancer_labels = train_df['cancer'].values.tolist()
    # class_zero =len(train_df[train_df['cancer']==0])
    # class_one = len(train_df[train_df['cancer']==1])
    # class_sample_count = np.array([class_zero, class_one*32])
    # weight = 1. / class_sample_count
    # samples_weight = np.array([weight[t] for t in cancer_labels])
    # samples_weight = torch.from_numpy(samples_weight)
    # samples_weight = samples_weight.double()
    # print(samples_weight)
    # sampler = torch.utils.data.sampler.WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)
    
    train_dataset = BreastDataset(train_df, transforms=data_transforms['train'])

    train_loader = DataLoader(train_dataset, batch_size = CFG.train_bs,
                                  num_workers=1, shuffle=True, pin_memory=True, drop_last=True)
    
    valid_dataset = BreastDataset(valid_df, transforms=data_transforms['valid'])

    valid_loader = DataLoader(valid_dataset, batch_size = CFG.valid_bs, 
                                  num_workers=1, shuffle=False, pin_memory=True, drop_last=False)
    
    LEN_DL_TRAIN = len(train_loader)
    best_f1 = 0
    best_metric = 0
    total_epoch = 13
    #checkpoint = torch.load("fold0/tf_efficientnetv2_b2_fold_0_model_epoch_1_0.0476_0.098.pth")
    #model.load_state_dict(checkpoint['state_dict'])
    # base_optimizer =torch.optim.AdamW
    # optimizer = SAM(model.parameters(),
    #                 base_optimizer,
    #                 rho=0.05,
    #                 lr=1e-4,
    #                 weight_decay=5e-4)
    optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4, weight_decay=5e-4)
    # optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  
    #optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps = 1*LEN_DL_TRAIN, num_training_steps =total_epoch*LEN_DL_TRAIN)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_epoch)
    # scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_epoch-1)
    # scheduler = GradualWarmupSchedulerV2(optimizer, multiplier=10, total_epoch=1, after_scheduler=scheduler_cosine)
    # swa_model = AveragedModel(model)
    # swa_scheduler = SWALR(optimizer, swa_lr=1e-4, anneal_epochs=0)
    #scheduler.load_state_dict(checkpoint['scheduler'])
    criterion = nn.CrossEntropyLoss().to(CFG.device)
    # criterion = BiTemperedLogisticLoss(t1=0.3, t2=1.00, smoothing=0.05).to(device)
    # criterion1 = nn.BCEWithLogitsLoss().to(device)
    # criterion = nn.BCEWithLogitsLoss().to(CFG.device)
    LOGGER.info(f"Train bs: {CFG.train_bs}")
    # LOGGER.info(f"Model: {model}")
    
    LOGGER.info(f"optimizer: {optimizer}")
    LOGGER.info(f"total_epoch :{total_epoch}")
#     criterion = FocalLoss().to(device)
    for epoch in range(2, total_epoch+1):
        # if epoch >=7:
        #     swa_model.update_parameters(model)
        #     swa_scheduler.step()
        # else:
        # scheduler.step(epoch-1)
        LOGGER.info(f"Epoch: {epoch}/{total_epoch}")
        loss_train = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, CFG.device)
        # state = {'epoch': epoch, 'state_dict': model.state_dict(),'optimizer': optimizer.state_dict(), 'scheduler':scheduler.state_dict()}
        # path = f'{CFG.model_name}_fold_{fold}_model_epoch_{epoch}.pth'
        # torch.save(state, path
        loss_valid, valid_preds, valid_labels = valid_fn_two(valid_loader, model, criterion, CFG.device)
        valid_preds = valid_preds[:, 1]
        # loss_valid, valid_preds_flip, valid_labels1 = valid_fn_two_flip(valid_loader, model, criterion, device)
        # print(valid_preds)
        # valid_preds = valid_preds[:, 1].reshape(-1, 1)
        # valid_preds_flip = valid_preds_flip[:, 1].reshape(-1, 1)
        # print(valid_preds.shape)
        # print(valid_preds_flip.shape)
        # valid_preds_final = np.reshape(np.average(np.concatenate([np.array(valid_preds), np.array(valid_preds_flip)], axis = 1), axis = 1), (-1, 1))
        # print(valid_preds_final.shape)
        valid_df['prediction_id'] = valid_df['patient_id'].astype(str) + '_' + valid_df['laterality'].astype(str)
        valid_preds = np.array(valid_preds).flatten()
        
        valid_df['raw_pred'] = valid_preds
        # LOGGER.info(f"Valid loss:{loss_valid:.4f}")
        LOGGER.info(f"Train loss:{loss_train:.4f}, Valid loss:{loss_valid:.4f}")
        # print(valid_df.head())
        grp_df = valid_df.groupby('prediction_id')['raw_pred', 'cancer'].mean()
        grp_df['cancer'] = grp_df['cancer'].astype(np.int)
        valid_labels_mean = grp_df['cancer'].values.tolist()
        valid_preds_mean = grp_df['raw_pred'].values.tolist()
        # print(valid_labels[:5], valid_preds_mean[:5])
        val_metric_mean = pfbeta(valid_labels_mean, valid_preds_mean)
        LOGGER.info(f"Val metric mean prob: {val_metric_mean:.4f}")
        best_metric_mean_at_epoch = 0
        best_threshold_mean = 0
        best_auc = 0
        best_cf = None
        for i in np.arange(0.001, 0.599, 0.001):
            valid_argmax = (valid_preds_mean>i).astype(np.int32)
    #             print(valid_argmax)
            val_metric = pfbeta(valid_labels_mean, valid_argmax)
            val_acc = accuracy_score(valid_labels_mean, valid_argmax)
            val_f1 = f1_score(valid_labels_mean, valid_argmax)
            val_auc = roc_auc_score(valid_labels_mean, valid_argmax)
            cf = confusion_matrix(valid_labels_mean, valid_argmax)
            if val_metric> best_metric_mean_at_epoch:
                best_metric_mean_at_epoch = val_metric
                best_threshold_mean = i
                best_auc = val_auc
                best_cf = cf
            # print(f"Threshold: {i:.4f}, val_acc: {val_acc:.4f}, val_f1: {val_f1:.4f}, val_auc: {val_auc:.4f}, val_metric: {val_metric:.4f}")
        LOGGER.info(f"Best metric at epoch {epoch}: {best_metric_mean_at_epoch:.4f} {best_threshold_mean:.4f} {best_auc:.4f}")
        LOGGER.info(f"Cf: {best_cf}")
    #         print(f"Train loss: {loss_train:.4f}, eval loss: {loss_valid.avg:.4f}") 
    #         print(f"Accuracy score: {val_acc:.4f}, f1 score: {val_f1:.4f}")
    #         print(f"Comp metric: {val_metric:.4f}")
        if best_metric_mean_at_epoch > best_metric:

            LOGGER.info(f"Model improve: {best_metric:.4f} -> {best_metric_mean_at_epoch:.4f}")
            best_metric = best_metric_mean_at_epoch
        state = {'epoch': epoch, 'state_dict': model.state_dict(),'optimizer': optimizer.state_dict(), 'scheduler':scheduler.state_dict()}
        path = f'fold{fold}/{CFG.model_name}_fold_{fold}_model_epoch_{epoch}_{best_metric_mean_at_epoch:.4f}_{best_threshold_mean:.3f}.pth'
        torch.save(state, path)
    #     loss_valid, valid_preds, valid_labels = valid_fn_no_sigmoid(valid_loader, model, criterion, device)
        
    #     # valid_preds = valid_preds[:, 1]
    #     valid_df['prediction_id'] = valid_df['patient_id'].astype(str) + '_' + valid_df['laterality'].astype(str)
    #     valid_preds = np.array(valid_preds).flatten()
        
    #     valid_df['raw_pred'] = valid_preds
    #     # LOGGER.info(f"Valid loss:{loss_valid:.4f}")
    #     LOGGER.info(f"Train loss:{loss_train:.4f}, Valid loss:{loss_valid:.4f}")
    #     # print(valid_df.head())
    #     grp_df = valid_df.groupby('prediction_id')['raw_pred', 'cancer'].mean()
    #     grp_df['cancer'] = grp_df['cancer'].astype(np.int)
    #     valid_labels_mean = grp_df['cancer'].values.tolist()
    #     valid_preds_mean = grp_df['raw_pred'].values.tolist()
    #     valid_preds_mean = [sigmoid(x) for x in valid_preds_mean]
    #     # print(valid_labels[:5], valid_preds_mean[:5])
    #     val_metric_mean = pfbeta(valid_labels_mean, valid_preds_mean)
    #     LOGGER.info(f"Val metric mean prob: {val_metric_mean:.4f}")
    #     best_metric_mean_at_epoch = 0
    #     best_threshold_mean = 0
    #     best_auc = 0
    #     best_cf = None
    #     for i in np.arange(0.001, 0.999, 0.001):
    #         valid_argmax = (valid_preds_mean>i).astype(np.int32)
    # #             print(valid_argmax)
    #         val_metric = pfbeta(valid_labels_mean, valid_argmax)
    #         val_acc = accuracy_score(valid_labels_mean, valid_argmax)
    #         val_f1 = f1_score(valid_labels_mean, valid_argmax)
    #         val_auc = roc_auc_score(valid_labels_mean, valid_argmax)
    #         cf = confusion_matrix(valid_labels_mean, valid_argmax)
    #         if val_metric> best_metric_mean_at_epoch:
    #             best_metric_mean_at_epoch = val_metric
    #             best_threshold_mean = i
    #             best_auc = val_auc
    #             best_cf = cf
    #         # print(f"Threshold: {i:.4f}, val_acc: {val_acc:.4f}, val_f1: {val_f1:.4f}, val_auc: {val_auc:.4f}, val_metric: {val_metric:.4f}")
    #     LOGGER.info(f"Best metric at epoch {epoch}: {best_metric_mean_at_epoch:.4f} {best_threshold_mean:.4f} {best_auc:.4f}")
    #     LOGGER.info(f"Cf: {best_cf}")
    # #         print(f"Train loss: {loss_train:.4f}, eval loss: {loss_valid.avg:.4f}") 
    # #         print(f"Accuracy score: {val_acc:.4f}, f1 score: {val_f1:.4f}")
    # #         print(f"Comp metric: {val_metric:.4f}")
    #     if best_metric_mean_at_epoch > best_metric:

    #         LOGGER.info(f"Model improve: {best_metric:.4f} -> {best_metric_mean_at_epoch:.4f}")
    #         best_metric = best_metric_mean_at_epoch
    #     state = {'epoch': epoch+1, 'state_dict': model.state_dict(),'optimizer': optimizer.state_dict(), 'scheduler':scheduler.state_dict()}
    #     path = f'{CFG.model_name}_fold_{fold}_model_epoch_{epoch+1}_{best_metric_mean_at_epoch:.4f}_{best_threshold_mean:.3f}.pth'
    #     torch.save(state, path)+









    
#     torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
        

Fold: 0
Model name: tf_efficientnetv2_b2


> SEEDING DONE


Len train df: 50278
Len valid df: 5471
Train bs: 16
optimizer: AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    initial_lr: 0.0001
    lr: 0.0
    maximize: False
    weight_decay: 0.0005
)
total_epoch :13
Epoch: 2/13
Train:   4%|▍         | 136/3142 [01:24<31:18,  1.60it/s, gpu_mem=0.48 GB, lr=0.00000433, train_loss=0.7055]


KeyboardInterrupt: 

In [None]:
out_file = 'swa_model_fold1.pth' 
iteration = [
    # 'fold1/tf_efficientnetv2_b2_fold_1_model_epoch_4_0.4385_0.205.pth',
    # 'fold1/tf_efficientnetv2_b2_fold_1_model_epoch_5_0.4393_0.278.pth',
    # 'fold1/tf_efficientnetv2_b2_fold_1_model_epoch_6_0.4432_0.319.pth',
    # 'fold1/tf_efficientnetv2_b2_fold_1_model_epoch_8_0.4231_0.320.pth',
    # 'fold1/tf_efficientnetv2_b2_fold_1_model_epoch_7_0.4578_0.382.pth',
    # 'fold1/tf_efficientnetv2_b2_fold_1_model_epoch_10_0.4339_0.246.pth',
    # 'fold1/tf_efficientnetv2_b2_fold_1_model_epoch_11_0.4211_0.242.pth',
    
    # 'fold0/tf_efficientnetv2_b2_fold_0_model_epoch_4_0.4151_0.352.pth',
    # 'fold0/tf_efficientnetv2_b2_fold_0_model_epoch_5_0.4757_0.230.pth',
    # 'fold0/tf_efficientnetv2_b2_fold_0_model_epoch_6_0.4520_0.128.pth',
    # 'fold0/tf_efficientnetv2_b2_fold_0_model_epoch_7_0.4510_0.266.pth',
    # 'fold0/tf_efficientnetv2_b2_fold_0_model_epoch_8_0.4403_0.415.pth',
    # 'fold0/tf_efficientnetv2_b2_fold_0_model_epoch_9_0.4713_0.430.pth',
    # 'fold0/tf_efficientnetv2_b2_fold_0_model_epoch_10_0.4569_0.259.pth',
    # 'fold0/tf_efficientnetv2_b2_fold_0_model_epoch_11_0.4387_0.436.pth'
    # 'fold2/tf_efficientnetv2_b2_fold_2_model_epoch_3_0.4000_0.122.pth',
    # 'fold2/tf_efficientnetv2_b2_fold_2_model_epoch_4_0.4585_0.236.pth',
    # 'fold2/tf_efficientnetv2_b2_fold_2_model_epoch_5_0.4149_0.131.pth',
    # 'fold2/tf_efficientnetv2_b2_fold_2_model_epoch_6_0.4516_0.188.pth',
    # 'fold2/tf_efficientnetv2_b2_fold_2_model_epoch_7_0.4557_0.241.pth',
    # 'fold2/tf_efficientnetv2_b2_fold_2_model_epoch_8_0.4455_0.208.pth',
    # 'fold2/tf_efficientnetv2_b2_fold_2_model_epoch_9_0.4681_0.319.pth',
    # 'fold2/tf_efficientnetv2_b2_fold_2_model_epoch_10_0.4550_0.245.pth',
    # 'fold2/tf_efficientnetv2_b2_fold_2_model_epoch_11_0.4500_0.373.pth',
    # 'fold2/tf_efficientnetv2_b2_fold_2_model_epoch_12_0.4457_0.298.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_3_0.3867_0.302.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_4_0.3924_0.275.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_6_0.4030_0.339.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_5_0.3850_0.161.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_7_0.4192_0.270.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_8_0.3913_0.362.pth'
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_4_0.4103_0.343.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_5_0.4041_0.141.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_6_0.4648_0.444.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_7_0.4103_0.310.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_8_0.4471_0.371.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_10_0.4062_0.202.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_7_0.4192_0.270.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_11_0.4309_0.199.pth',
    # 'fold3/tf_efficientnetv2_b2_fold_3_model_epoch_12_0.4074_0.278.pth'
    # 'fold4/tf_efficientnetv2_b2_fold_4_model_epoch_5_0.3889_0.407.pth',
    # 'fold4/tf_efficientnetv2_b2_fold_4_model_epoch_4_0.4276_0.403.pth',
    # 'fold4/tf_efficientnetv2_b2_fold_4_model_epoch_6_0.4000_0.586.pth',
    # 'fold4/tf_efficientnetv2_b2_fold_4_model_epoch_7_0.3913_0.444.pth',
    # 'fold4/tf_efficientnetv2_b2_fold_4_model_epoch_8_0.3916_0.483.pth'

]
#46789101112 4824
state_dict = None
for i in iteration:
    f = i
    print(f)
    f = torch.load(f, map_location=lambda storage, loc: storage)
    if state_dict is None:
        state_dict = f['state_dict']
    else:
        key = list(f['state_dict'].keys())
        for k in key:
            state_dict[k] = state_dict[k] + f['state_dict'][k]

for k in key:
    state_dict[k] = state_dict[k] / len(iteration)
print('')

print(out_file)
torch.save({'state_dict': state_dict}, out_file)

fold1/tf_efficientnetv2_b2_fold_1_model_epoch_4_0.4385_0.205.pth
fold1/tf_efficientnetv2_b2_fold_1_model_epoch_5_0.4393_0.278.pth
fold1/tf_efficientnetv2_b2_fold_1_model_epoch_6_0.4432_0.319.pth
fold1/tf_efficientnetv2_b2_fold_1_model_epoch_8_0.4231_0.320.pth
fold1/tf_efficientnetv2_b2_fold_1_model_epoch_7_0.4578_0.382.pth
fold1/tf_efficientnetv2_b2_fold_1_model_epoch_10_0.4339_0.246.pth
fold1/tf_efficientnetv2_b2_fold_1_model_epoch_11_0.4211_0.242.pth
fold1/tf_efficientnetv2_b2_fold_1_model_epoch_6_0.4235_0.473.pth

swa_model_fold1.pth


In [None]:
avail_pretrained_models = timm.list_models(pretrained=True)
print(avail_pretrained_models)

['adv_inception_v3', 'bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_224_in22k', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_224_in22k', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', 'beitv2_base_patch16_224_in22k', 'beitv2_large_patch16_224', 'beitv2_large_patch16_224_in22k', 'botnet26t_256', 'cait_m36_384', 'cait_m48_448', 'cait_s24_224', 'cait_s24_384', 'cait_s36_384', 'cait_xs24_384', 'cait_xxs24_224', 'cait_xxs24_384', 'cait_xxs36_224', 'cait_xxs36_384', 'coat_lite_mini', 'coat_lite_small', 'coat_lite_tiny', 'coat_mini', 'coat_tiny', 'coatnet_0_rw_224', 'coatnet_1_rw_224', 'coatnet_bn_0_rw_224', 'coatnet_nano_rw_224', 'coatnet_rmlp_1_rw_224', 'coatnet_rmlp_nano_rw_224', 'convit_base', 'convit_small', 'convit_tiny', 'convmixer_768_32', 'convmixer_1024_20_ks9_p14', 'convmixer_1536_20', 'convnext_atto', 'convnext_atto_ols', 'convnext_base', 'convnext_base_384_in22ft1k', 'convnext_base_in22ft1k', 'convnext_base

In [None]:
# for fold in [0, 1, 2, 3, 4]:
#     LOGGER.info(f"Fold: {fold}")
#     model = Model(model_name=CFG.model_name).to(device)
#     # model = ModelVIT().to(CFG.device)
#     train_df = df1[df1['fold']!=fold].reset_index(drop=True)
#     valid_df = df[df['fold']==fold].reset_index(drop=True)
#     # print(len(valid_df))
#     LOGGER.info(f"Len train df: {len(train_df)}")
#     LOGGER.info(f"Len valid df: {len(valid_df)}")

Fold: 0
Len train df: 45577
Len valid df: 10979
Fold: 1
Len train df: 45673
Len valid df: 10879
Fold: 2
Len train df: 45512
Len valid df: 11012
Fold: 3
Len train df: 45689
Len valid df: 10891
Fold: 4
Len train df: 45637
Len valid df: 10945
