In [None]:
! pip install timm

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
import timm
import albumentations

from sklearn.model_selection import train_test_split
from sklearn import metrics
from torch.utils.data import Dataset, DataLoader
from skimage import io, color

In [None]:
class BratsDataset(Dataset):
    """Map-style dataset returning one (image, MGMT label) sample per index.

    Every sample is read from disk with ``io.imread``, promoted to three
    channels with ``color.gray2rgb``, optionally passed through
    ``transforms``, and finally returned as channel-first float tensors.

    Args:
        image_path: indexable sequence of image file paths.
        MGMT_value: indexable sequence of labels aligned with ``image_path``.
        transforms: optional callable invoked as ``transforms(image=...)``
            that returns a dict holding the transformed array under
            ``"image"`` (albumentations-style).
    """

    def __init__(self, image_path, MGMT_value, transforms=None):
        self.image_path, self.MGMT_value, self.transforms = (
            image_path,
            MGMT_value,
            transforms,
        )

    def __len__(self):
        # The dataset has one sample per listed image path.
        return len(self.image_path)

    def __getitem__(self, item):
        raw = io.imread(self.image_path[item])
        label = self.MGMT_value[item]
        rgb = color.gray2rgb(raw)

        if self.transforms is not None:
            rgb = self.transforms(image=rgb)["image"]

        # HWC -> CHW, the layout torch vision backbones consume.
        chw = np.transpose(rgb, (2, 0, 1)).astype(np.float32)
        sample = {
            'image': torch.tensor(chw, dtype=torch.float),
            'targets': torch.tensor(label, dtype=torch.float),
        }
        return sample


In [None]:
def train(model,train_loader,device,optimizer):
    """Run one training epoch, print and return the mean BCE-with-logits loss.

    Args:
        model: module mapping a float image batch to logits; the loss is
            computed against ``targets.view(-1, 1)``, so the output is
            presumably shaped (batch, 1) — TODO confirm for other heads.
        train_loader: iterable of dicts with ``'image'`` and ``'targets'``
            tensors (e.g. a DataLoader over ``BratsDataset``).
        device: device the inputs and targets are moved to.
        optimizer: optimizer stepping ``model``'s parameters.

    Returns:
        The mean per-batch training loss as a Python float, or ``nan`` when
        ``train_loader`` yields no batches (instead of dividing by zero).
    """
    model.train()
    # The loss module has no learnable state, so build it once per epoch.
    criterion = nn.BCEWithLogitsLoss()
    running_train_loss = 0.0
    n_batches = 0
    for data in train_loader:
        inputs = data['image'].to(device, dtype=torch.float)
        targets = data['targets'].to(device, dtype=torch.float)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets.view(-1, 1))
        loss.backward()
        optimizer.step()
        # .item() detaches to a Python float so no autograd graph is retained.
        running_train_loss += loss.item()
        n_batches += 1

    train_loss_value = running_train_loss / n_batches if n_batches else float('nan')
    print(f'train BCE loss is {train_loss_value}')
    return train_loss_value
    
def eval(model,valid_loader,device,optimizer):
    """Evaluate ``model`` on a validation loader without updating it.

    Prints the mean per-batch BCE-with-logits loss as a plain number.
    NOTE: this name shadows Python's built-in ``eval`` in the notebook namespace.

    Args:
        model: module mapping a float image batch to logits; presumably shaped
            (batch, 1) given ``targets.view(-1, 1)`` — TODO confirm.
        valid_loader: iterable of dicts with ``'image'`` and ``'targets'`` tensors.
        device: device the inputs and targets are moved to.
        optimizer: unused; kept only so existing call sites keep working.

    Returns:
        ``(final_outputs, final_targets)``: ``final_outputs`` is a list with one
        inner list of sigmoid probabilities per output row; ``final_targets`` is
        a flat list of float targets. Both are empty for an empty loader.
    """
    model.eval()
    criterion = nn.BCEWithLogitsLoss()
    final_targets = []
    final_outputs = []
    running_val_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for data in valid_loader:
            inputs = data['image'].to(device, dtype=torch.float)
            targets = data['targets'].to(device, dtype=torch.float)

            output = model(inputs)
            # Accumulate a Python float (not a tensor) so the printed loss is a
            # number rather than a tensor repr.
            running_val_loss += criterion(output, targets.view(-1, 1)).item()
            final_outputs.extend(torch.sigmoid(output).cpu().numpy().tolist())
            final_targets.extend(targets.cpu().numpy().tolist())
            n_batches += 1
    # Guard against an empty loader instead of raising ZeroDivisionError.
    val_loss = running_val_loss / n_batches if n_batches else float('nan')
    print(f'valid BCE loss is {val_loss}')
    return final_outputs,final_targets

In [None]:
import torch
import torch.nn as nn

# timm identifier of the pretrained backbone.
model_name = 'swin_base_patch4_window7_224'

# Width of the final layer (1 -> a single logit, as consumed by BCEWithLogitsLoss).
out_dim    = 1

class get_model(torch.nn.Module):
    """timm backbone whose classifier head is replaced by a stack of linear layers.

    Args:
        backbone: timm model name. Defaults to the module-level ``model_name``,
            read at construction time (same as before).
        num_outputs: width of the final layer. Defaults to the module-level
            ``out_dim``, which was previously defined but ignored.
        pretrained: whether to load pretrained backbone weights (default True).

    ``state_dict`` keys are unchanged (``model.head.0``, ``model.head.1``,
    ``last``, ``depth1``, ``depth2``), so existing checkpoints still load.

    NOTE(review): every layer after the backbone is a plain ``nn.Linear`` with
    no activation in between, so in_features->768->256->128->64->num_outputs
    composes into a single affine map. Insert non-linearities if the extra
    depth is intended. That would change ``state_dict`` keys, so old
    checkpoints would no longer load.
    NOTE(review): overwriting ``self.model.head`` and reading
    ``head.in_features`` assumes a timm version whose Swin head is a bare
    ``nn.Linear`` applied to pooled features — TODO confirm against the timm
    version actually installed.
    """
    def __init__(self, backbone=None, num_outputs=None, pretrained=True):
        super().__init__()
        if backbone is None:
            backbone = model_name
        if num_outputs is None:
            num_outputs = out_dim
        self.model = timm.create_model(backbone, pretrained=pretrained)
        self.model.head = nn.Sequential(
            nn.Linear(self.model.head.in_features, 768),
            nn.Linear(768, 256),
        )
        self.last = nn.Linear(256, 128)
        self.depth1 = nn.Linear(128, 64)
        self.depth2 = nn.Linear(64, num_outputs)

    def forward(self, image):
        """Return raw logits (no sigmoid) for an image batch.

        Presumably shaped (batch, num_outputs) when the backbone's replaced head
        receives pooled features — TODO confirm for the installed timm version.
        """
        x = self.model(image)
        x = self.last(x)
        x = self.depth1(x)
        x = self.depth2(x)
        return x

# Dataloaders

In [None]:
# Fixed seed: makes the train/valid split reproducible across runs. Seeding
# torch's global RNG also fixes DataLoader shuffling and the head's weight
# init, since the model is created after this cell.
SEED = 42
torch.manual_seed(SEED)

df = pd.read_csv('../input/train-png-middlecsv/train_rsna_png_T1wCE.csv')

# Stratify on the label so both splits keep the same MGMT class balance.
df_train, df_valid = train_test_split(
    df, test_size=0.3, stratify=df.MGMT_value, random_state=SEED
)
assert len(df_train) + len(df_valid) == len(df)

image_path_train = df_train.path.values.tolist()
image_path_valid = df_valid.path.values.tolist()

# Deterministic preprocessing only (resize + ImageNet mean/std normalisation);
# the same pipeline is used for both splits, so no training augmentation is applied.
aug = albumentations.Compose([
    albumentations.Resize(224, 224, p=1),
    # p=1.0 replaces the deprecated always_apply=True (equivalent behaviour).
    albumentations.Normalize(
        (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0
    ),
])
train_dataset = BratsDataset(image_path_train, df_train.MGMT_value.values, transforms=aug)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

valid_dataset = BratsDataset(image_path_valid, df_valid.MGMT_value.values, transforms=aug)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

In [None]:
from itertools import chain

# Prefer the GPU, but fall back to CPU so the cell still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_EPOCHS = 18
LEARNING_RATE = 1e-6

model = get_model()
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    print(f'==================== Epoch -- {epoch} ====================')
    train(model=model, train_loader=train_loader, device=device, optimizer=optimizer)

    final_outputs, final_targets = eval(
        model=model, valid_loader=valid_loader, device=device, optimizer=optimizer
    )

    # eval() returns one inner list per output row (a single score each for a
    # one-logit head); flatten into the 1-D score list roc_auc_score expects.
    final_outputs = list(chain.from_iterable(final_outputs))

    # Report ROC AUC itself. Taking np.sqrt of it (as before) inflates the
    # metric, e.g. a chance-level 0.49 would print as 0.70.
    auc = metrics.roc_auc_score(final_targets, final_outputs)
    print(f'valid ROC AUC={auc}')

# Saves the weights from the final epoch only (not the best-scoring epoch).
torch.save(model.state_dict(), 'model-epoch.pth')