In [None]:
import gc
import os

import numpy as np
import timm
import torch
from torch import nn


In [None]:
# Input geometry: each slice handed to the 2D encoder is in_chans x image_size x image_size.
in_chans = 4
image_size = 296

# Classification head widths.
out_dim_be = 2   # bowel model: encoder num_classes and TimmModelBowel head outputs
out_dim_lsk = 3  # not referenced elsewhere in this notebook section

# Regularisation rates (all zero here).
drop_rate = 0
drop_path_rate = 0
drop_rate_last = 0

# timm model identifiers.
backbone_bowel = "tf_efficientnetv2_b0.in1k"
backbone_lsk = "tf_efficientnetv2_b2.in1k"  # not referenced elsewhere in this section

# Checkpoints are read from {model_bowel_dir}/{kernel_bowel_type}_fold{k}_best.pth.
# NOTE(review): the run name mentions "effv2s", "224" and "6ch", while backbone_bowel is
# effv2-b0, image_size is 296 and in_chans is 4 -- confirm these match the checkpoints.
model_bowel_dir = './models_bowel'
kernel_bowel_type = '0920_1bonev2_effv2s_224_15_6ch_augv2_mixupp5_drl3_rov1p2_bs8_lr23e5_eta23e6_50ep'

device = "cuda"

In [None]:
class Attention(nn.Module):
    """Attention pooling over axis 1 (the "step" / slice axis) of a 3D tensor.

    Each step is scored with a learned (feature_dim -> 1) projection. The score is
    squashed with tanh and exponentiated, optionally multiplied by a mask, and
    normalised across steps. The result is the weighted sum of the steps.
    """

    def __init__(self, feature_dim, **kwargs):
        super().__init__(**kwargs)

        # Attributes carried over from a Keras-style layer; forward() does not read
        # them, but they are kept so external code touching them keeps working.
        self.supports_masking = True
        self.features_dim = 0

        self.feature_dim = feature_dim

        weight = torch.zeros(feature_dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight)

    def forward(self, x, mask=None):
        """Pool ``x`` over its second axis.

        Args:
            x: tensor whose last dim equals ``feature_dim``; the view() calls below
               treat it as (batch, steps, feature_dim).
            mask: optional tensor broadcastable to (batch, steps). Zero entries
               exclude the corresponding step from the weighted sum.

        Returns:
            Tensor of shape (batch, feature_dim).
        """
        step_dim = x.shape[1]
        eij = torch.mm(
            x.contiguous().view(-1, self.feature_dim),
            self.weight,
        ).view(-1, step_dim)

        eij = torch.tanh(eij)
        a = torch.exp(eij)  # unnormalised step importances, bounded in [1/e, e]

        if mask is not None:
            a = a * mask

        # Normalise across steps. The epsilon belongs in the denominator so that a
        # fully-masked row yields zero weights instead of 0/0 = NaN.
        a = a / (torch.sum(a, 1, keepdim=True) + 1e-10)

        weighted_input = x * torch.unsqueeze(a, -1)
        return torch.sum(weighted_input, 1)


In [None]:
class TimmModelBowel(nn.Module):
    """Per-slice timm encoder followed by a bidirectional LSTM across slices.

    Input to forward(): x of shape (bs, nslices, ch, H, W). Each slice is encoded
    independently, the resulting sequence is passed through a BiLSTM, and then:
      * features=True  -> the BiLSTM output (bs, nslices, 2*hlstm) is returned;
      * features=False -> (out, out2), where out is the per-slice head output
        (bs, nslices, out_dim_be) and out2 is a sequence-level output
        (bs, out_dim_be) from attention pooling concatenated with max pooling.

    Uses notebook-level constants: in_chans, out_dim_be, drop_rate,
    drop_path_rate, drop_rate_last.
    """

    def __init__(self, backbone, pretrained=False, features=False):
        super().__init__()
        self.features = features
        self.encoder = timm.create_model(
            backbone,
            in_chans=in_chans,
            num_classes=out_dim_be,
            features_only=False,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained,
        )

        # Replace the classifier with Identity so the encoder emits pooled features;
        # hdim is the width of those features.
        if 'efficient' in backbone:
            hdim = self.encoder.conv_head.out_channels
            self.encoder.classifier = nn.Identity()
        elif 'convnext' in backbone:
            hdim = self.encoder.head.fc.in_features
            self.encoder.head.fc = nn.Identity()
        else:
            # Previously fell through and failed later with an UnboundLocalError on hdim.
            raise ValueError(
                f"Unsupported backbone {backbone!r}: expected an 'efficient*' or 'convnext*' timm model"
            )

        hlstm = 64
        # NOTE: with num_layers=1, nn.LSTM applies no dropout (it warns if dropout > 0).
        self.lstm = nn.LSTM(hdim, hlstm, num_layers=1, dropout=drop_rate, bidirectional=True, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(2 * hlstm, hlstm),
            nn.BatchNorm1d(hlstm),
            nn.Dropout(drop_rate_last),
            nn.LeakyReLU(0.1),
            nn.Linear(hlstm, out_dim_be),
        )
        # Input is [attention-pooled ; max-pooled] BiLSTM features -> 2 * (2 * hlstm).
        self.head2 = nn.Linear(4 * hlstm, out_dim_be)
        self.attention = Attention(2 * hlstm)

    def forward(self, x):  # (bs, nslice, ch, sz, sz)
        bs, nslices = x.shape[0], x.shape[1]
        # Fold slices into the batch axis so the 2D encoder sees each slice on its own.
        x = x.reshape(bs * nslices, *x.shape[2:])
        feat = self.encoder(x)
        feat = feat.view(bs, nslices, -1)
        feat, _ = self.lstm(feat)

        if self.features:
            # The heads below are not needed for feature extraction; skip them.
            return feat

        featv = feat.contiguous().view(bs * nslices, -1)
        out = self.head(featv).view(bs, nslices, out_dim_be).contiguous()

        att = self.attention(feat)
        max_ = feat.max(dim=1)[0]
        out2 = self.head2(torch.cat((att, max_), dim=1))

        return out, out2

In [None]:

def infer(ifold):
    """Extract BiLSTM slice features for every row of ``df`` using fold ``ifold``'s bowel model.

    Loads ``{model_bowel_dir}/{kernel_bowel_type}_fold{ifold}_best.pth`` into a
    TimmModelBowel built with ``features=True``. Each item of
    ``CLSDataset(df, 'test', transform=transforms_valid)`` is run through it, and the
    per-slice features are saved to ``features2b/model{ifold}/{patient_id}_{series_id}.npy``.

    Relies on notebook globals defined outside this section: ``df``, ``CLSDataset``,
    ``transforms_valid`` and ``tqdm``.
    """
    print(ifold)
    model_bowel = TimmModelBowel(backbone_bowel, pretrained=False, features=True)
    model_bowel_file = os.path.join(model_bowel_dir, f'{kernel_bowel_type}_fold{ifold}_best.pth')
    # Load onto CPU first so the device the checkpoint was saved from doesn't matter.
    model_bowel.load_state_dict(torch.load(model_bowel_file, map_location="cpu"))
    # eval(): BatchNorm uses running statistics and dropout is disabled during extraction.
    model_bowel = model_bowel.to(device).eval()

    dataset = CLSDataset(df, 'test', transform=transforms_valid)
    save_path = f"features2b/model{ifold}"
    os.makedirs(save_path, exist_ok=True)
    with torch.no_grad():
        for ind, row in tqdm(df.iterrows(), total=len(df)):
            # NOTE(review): indexes the dataset with df's index label; assumes that
            # matches CLSDataset's indexing (e.g. a RangeIndex) -- confirm.
            images = dataset[ind].to(device)[None, ...]
            # Drop only the batch axis; a bare .squeeze() would also drop nslices == 1.
            features = model_bowel(images)[0].numpy(force=True)
            np.save(os.path.join(save_path, f"{row.patient_id}_{row.series_id}"), features)

    del model_bowel
    torch.cuda.empty_cache()
    gc.collect()
    
# Set to False to skip the (slow) per-fold feature extraction.
PREP_FEATURES = True
N_FOLDS = 5  # folds 0..N_FOLDS-1 each have a trained bowel checkpoint
if PREP_FEATURES:
    # The extraction function is named `infer`; `infer_features` was never defined.
    for ifold in range(N_FOLDS):
        infer(ifold)
    