This notebook is based on the 2nd place solution of the RSNA Intracranial Hemorrhage Detection competition.  
https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection/discussion/117228

In [None]:
import ast
import gc
import math
import os

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torchvision import models

# Number of CUDA devices visible to PyTorch, printed for reference.
n_gpu = torch.cuda.device_count()
print("n_gpus: ", n_gpu)

In [None]:
def set_seeds(SEED):
    """Seed the Python, NumPy and PyTorch random generators.

    Also asks cuDNN for deterministic kernels and disables its auto-tuner so
    that repeated runs pick the same algorithms.

    Args:
        SEED: Integer seed applied to every generator.
    """
    # Local import: the notebook's import cell does not load ``random``.
    import random

    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    # Safe to call without CUDA (the call is ignored), so no device-count
    # check against a notebook-level global is needed.
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seeds(SEED=2020)

In [None]:
# Target columns; the number of model outputs is derived from this list.
label_cols = [
    "negative_exam_for_pe",
    "rv_lv_ratio_gte_1",
    "rv_lv_ratio_lt_1",
    "leftsided_pe",
    "chronic_pe",
    "rightsided_pe",
    "acute_and_chronic_pe",
    "central_pe",
    "indeterminate",
]

n_classes = len(label_cols)
n_epochs = 10
batch_size = 8
# os.cpu_count() returns None when the count cannot be determined; fall back
# to 0 (load data in the main process) so DataLoader always gets an int.
num_workers = os.cpu_count() or 0

# LSTM_UNITS = 512
lr = 1e-5        # initial learning rate handed to the optimizer
lrgamma = 0.95   # multiplicative LR decay used by the StepLR scheduler
DECAY = 0.0      # weight decay for the parameters that are not exempted

In [None]:
# Metadata table for the training images; later cells read its
# SeriesInstanceUID, StudyInstanceUID, SOPInstanceUID and
# ImagePositionPatient columns.
trnmdf = pd.read_csv('../input/extract-pe-meta-data/train_metadata.csv')
# tstmdf = pd.read_csv(os.path.join(path_data, 'test_metadata.csv'))

In [None]:
%%time

# Grouping key "<SeriesInstanceUID>__<StudyInstanceUID>" shared by all rows of
# one series/study pair. Built with vectorised string concatenation instead of
# a row-wise apply, which runs one Python call per row.
trnmdf['SliceID'] = (
    trnmdf['SeriesInstanceUID'].astype(str)
    + '__'
    + trnmdf['StudyInstanceUID'].astype(str)
)
# tstmdf['SliceID'] = tstmdf[['SeriesInstanceUID', 'StudyInstanceUID']].apply(
#     lambda x: '{}__{}'.format(*x.tolist()), 1)

In [None]:
# Preview the first rows, now including the SliceID column.
trnmdf.head()

In [None]:
%%time

# Expand ImagePositionPatient (a stringified sequence) into three float
# columns, then order the rows by series key and position.
poscols = ['ImagePos1', 'ImagePos2', 'ImagePos3']
trnmdf[poscols] = pd.DataFrame(
    [
        [float(coord) for coord in ast.literal_eval(raw)]
        for raw in trnmdf['ImagePositionPatient']
    ]
)
# tstmdf[poscols] = pd.DataFrame(tstmdf['ImagePositionPatient']\
#               .apply(lambda x: list(map(float, ast.literal_eval(x)))).tolist())

# NOTE(review): within a SliceID the rows are ordered by ImagePos1, then
# ImagePos2, then ImagePos3 - confirm this is the intended slice order.
trnmdf = (
    trnmdf
    .sort_values(by=['SliceID', *poscols])
    .loc[:, ['StudyInstanceUID', 'SliceID', 'SOPInstanceUID', *poscols]]
    .reset_index(drop=True)
)
# tstmdf = tstmdf.sort_values(['SliceID']+poscols)\
#                 [['SliceID', 'SOPInstanceUID']+poscols].reset_index(drop=True)

In [None]:
# Preview the sorted frame with the position columns.
trnmdf.head()

In [None]:
# 1-based running index of each row inside its SliceID group, following the
# current row order of the frame.
trnmdf['seq'] = (trnmdf.groupby(['SliceID']).cumcount() + 1)
# tstmdf['seq'] = (tstmdf.groupby(['SliceID']).cumcount() + 1)

In [None]:
# Keep only the identifier columns and the sequence index; the position
# columns are dropped here.
keepcols = ['StudyInstanceUID', 'SliceID', 'SOPInstanceUID', 'seq']
trnmdf = trnmdf[keepcols]
# tstmdf = tstmdf[keepcols]

In [None]:
%%time

# Open the train/valid .npz archives. Each one is read for an "embddings"
# array (key spelling kept exactly as stored) and three UID arrays.
train_emb_f = np.load(
    "../input/extract-resnet18-features-for-seqeunce-model/emb_train_embdim512.npz"
)
valid_emb_f = np.load(
    "../input/extract-resnet18-features-for-seqeunce-model/emb_valid_embdim512.npz"
)

train_emb = train_emb_f["embddings"]
valid_emb = valid_emb_f["embddings"]

# One identifier row per embedding row, in archive order.
trndf = pd.DataFrame({
    uid: train_emb_f[uid]
    for uid in ("StudyInstanceUID", "SeriesInstanceUID", "SOPInstanceUID")
})
valdf = pd.DataFrame({
    uid: valid_emb_f[uid]
    for uid in ("StudyInstanceUID", "SeriesInstanceUID", "SOPInstanceUID")
})

In [None]:
# Attach StudyInstanceUID, SliceID and seq to every embedding row. Only
# SOPInstanceUID is kept on the left so the UID columns are not duplicated.
# NOTE(review): assumes SOPInstanceUID is unique in trnmdf; a duplicate would
# add rows and break the row <-> embedding correspondence - confirm.
trndf = trndf[["SOPInstanceUID"]].merge(trnmdf, on="SOPInstanceUID", how="left")
valdf = valdf[["SOPInstanceUID"]].merge(trnmdf, on="SOPInstanceUID", how="left")

In [None]:
# Show the (rows, columns) of both frames after the metadata merge.
print(trndf.shape, valdf.shape)

In [None]:
# Preview the training frame after the metadata merge.
trndf.head()

In [None]:
# embidx is the row number used to index train_emb / valid_emb, so each frame
# must still hold exactly one row per embedding. Fail loudly instead of
# silently pairing rows with the wrong embeddings.
if len(trndf) != len(train_emb) or len(valdf) != len(valid_emb):
    raise ValueError(
        "metadata/embedding row mismatch: train {} vs {}, valid {} vs {}".format(
            len(trndf), len(train_emb), len(valdf), len(valid_emb)))
trndf['embidx'] = range(trndf.shape[0])
valdf['embidx'] = range(valdf.shape[0])

In [None]:
# Preview the training frame with the embidx column.
trndf.head()

In [None]:
# Competition training table; it is merged on SOPInstanceUID below to supply
# the label columns.
train_df = pd.read_csv("../input/rsna-str-pulmonary-embolism-detection/train.csv")
# tstdf = pd.read_csv("../input/rsna-str-pulmonary-embolism-detection/test.csv")

In [None]:
# Add the columns of train_df (labels included) to every row. StudyInstanceUID
# is removed from the right-hand frame because trndf/valdf already carry it.
# drop(columns=...) is used because the positional axis argument of
# DataFrame.drop is no longer accepted by pandas 2.x.
trndf = trndf.merge(train_df.drop(columns="StudyInstanceUID"), on="SOPInstanceUID", how="left")
valdf = valdf.merge(train_df.drop(columns="StudyInstanceUID"), on="SOPInstanceUID", how="left")

In [None]:
# Show the frame shapes after the label merge.
print(trndf.shape)
print(valdf.shape)

In [None]:
# Preview the training frame with labels attached.
trndf.head()

In [None]:
# Preview the validation frame with labels attached.
valdf.head()

In [None]:
# Drop the metadata and label frames (already merged into trndf / valdf) and
# run a garbage-collection pass.
del trnmdf, train_df
gc.collect()

In [None]:
class PEDataset(data.Dataset):
    """Dataset that yields one embedding sequence per SliceID.

    Args:
        df: Frame with ``SliceID``, ``seq`` and ``embidx`` columns, plus the
            ``label_cols`` columns when ``labels`` is true.
        mat: Array indexed along its first axis by ``embidx``.
        labels: When true, each item also carries a ``labels`` tensor.
    """

    def __init__(self, df, mat, labels=True):
        self.mat = mat
        self.labels = labels
        self.patients = df.SliceID.unique()
        self.data = df.set_index('SliceID')

    def __len__(self):
        """Return the number of distinct SliceID values."""
        return len(self.patients)

    def __getitem__(self, idx):
        """Build the feature sequence for the ``idx``-th SliceID.

        Returns:
            A dict with ``emb`` (array whose rows are the embedding followed
            by its difference to the previous and to the next row),
            ``embidx`` (tensor of the source row numbers) and, when labels
            are enabled, ``labels`` (label values of the first row).
        """
        patidx = self.patients[idx]
        # A list selector always returns a DataFrame. A scalar selector
        # returns a Series for a single-row group, which cannot be sorted by
        # the 'seq' column.
        patdf = self.data.loc[[patidx]].sort_values('seq')
        patemb = self.mat[patdf['embidx'].values]

        # Differences to the previous / next row; the first lag row and the
        # last lead row stay zero.
        patdeltalag = np.zeros(patemb.shape)
        patdeltalead = np.zeros(patemb.shape)
        patdeltalag[1:] = patemb[1:] - patemb[:-1]
        patdeltalead[:-1] = patemb[:-1] - patemb[1:]

        patemb = np.concatenate((patemb, patdeltalag, patdeltalead), -1)

        ids = torch.tensor(patdf['embidx'].values)

        assert len(patemb) == len(ids), "emb size: {} id size: {}".format(len(patemb), len(ids))

        if self.labels:
            # Only the first row's labels are used for the whole sequence.
            labels = torch.tensor(patdf[label_cols].values[0])
            return {'emb': patemb, 'embidx': ids, 'labels': labels}
        return {'emb': patemb, 'embidx': ids}

In [None]:
def collatefn(batch):
    """Left-pad variable-length samples and stack them into batch tensors.

    The input sample dicts are not modified, so calling this twice on the
    same list gives the same result.

    Args:
        batch: Non-empty list of dicts with ``emb`` (2-D array, rows are
            sequence steps), ``embidx`` (1-D long tensor) and optionally
            ``labels``.

    Returns:
        Dict of float tensors: ``emb`` (batch, maxlen, embdim), ``mask``
        (batch, maxlen; 0 on padding, 1 on real steps), ``embidx``
        (batch, maxlen; -1 on padding) and ``labels`` when the first sample
        has them.
    """
    maxlen = max(len(sample['emb']) for sample in batch)
    embdim = batch[0]['emb'].shape[1]
    withlabel = 'labels' in batch[0]

    embs, masks, embidxs = [], [], []
    for sample in batch:
        masklen = maxlen - len(sample['emb'])
        # Padding goes in front of the real rows.
        embs.append(np.vstack((np.zeros((masklen, embdim)), sample['emb'])))
        padidx = torch.full((masklen,), -1, dtype=torch.long)
        embidxs.append(torch.cat((padidx, sample['embidx'])))
        mask = np.ones(maxlen)
        mask[:masklen] = 0.
        masks.append(mask)

    outbatch = {'emb': torch.tensor(np.stack(embs)).float()}
    outbatch['mask'] = torch.tensor(np.stack(masks)).float()
    # NOTE: kept as float for compatibility with existing consumers; float32
    # represents integers exactly only up to 2**24.
    outbatch['embidx'] = torch.stack(embidxs).float()
    if withlabel:
        outbatch["labels"] = torch.tensor(np.vstack([b["labels"] for b in batch])).float()
    return outbatch

In [None]:
# Datasets yield one sequence per SliceID; collatefn pads every batch to its
# longest sequence. Only the training loader shuffles.
train_dataset = PEDataset(trndf, train_emb, labels=True)
train_loader = data.DataLoader(train_dataset, 
                               batch_size=batch_size, 
                               shuffle=True, 
                               num_workers=num_workers, 
                               collate_fn=collatefn)

valid_dataset = PEDataset(valdf, valid_emb, labels=True)
valid_loader = data.DataLoader(valid_dataset, 
                               batch_size=batch_size, 
                               shuffle=False, 
                               num_workers=num_workers, 
                               collate_fn=collatefn)

In [None]:
class PositionalEncoding(nn.Module):
    """Add a fixed sine/cosine position table to the input, then apply dropout.

    The table is stored as the non-trainable buffer ``pe`` with shape
    (max_len, 1, d_model); ``forward`` slices it by the size of the input's
    first dimension and broadcasts it over the second.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = positions * inv_freq

        # Even feature columns carry the sine, odd columns the cosine.
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``."""
        n_positions = x.size(0)
        shifted = x + self.pe[:n_positions, :]
        return self.dropout(shifted)

In [None]:
class PETransformerModel(nn.Module):
    """Transformer encoder over an embedding sequence with a linear head.

    Args:
        n_classes: Number of output logits per sequence.
        ninp: Feature size of every sequence step (the encoder's d_model).
        nhead: Number of attention heads.
        nhid: Hidden size of the encoder's feed-forward sub-layer.
        nlayers: Number of stacked encoder layers.
        dropout: Dropout rate for the positional encoding and encoder layers.
    """

    def __init__(self, n_classes, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(PETransformerModel, self).__init__()
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, n_classes)

        self.init_weigths()

    def init_weigths(self):
        """Zero the head's bias and draw its weights uniformly from [-0.1, 0.1]."""
        initrange = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, mask=None):
        """Compute one logit vector per sequence.

        Args:
            x: Batch-first input of shape (batch, seq, ninp).
            mask: Optional (batch, seq) tensor that is non-zero on real steps
                and zero on padding. When given, padded steps are hidden from
                attention and left out of the average; when omitted, every
                step is treated as real.

        Returns:
            Tensor of shape (batch, n_classes) holding raw logits.
        """
        x = x * math.sqrt(self.ninp)
        # The positional encoding and the encoder layers (built without
        # batch_first) both expect (seq, batch, feature). Without this
        # transpose, attention would run across the samples of a batch
        # instead of across the steps of each sequence.
        x = x.transpose(0, 1)
        x = self.pos_encoder(x)

        if mask is None:
            hidden = self.transformer_encoder(x).transpose(0, 1)
            pooled = hidden.mean(1)
        else:
            valid = mask.to(device=x.device, dtype=torch.bool)
            # src_key_padding_mask marks the positions to ignore with True.
            hidden = self.transformer_encoder(x, src_key_padding_mask=~valid)
            hidden = hidden.transpose(0, 1)
            weights = valid.unsqueeze(-1).to(hidden.dtype)
            # Average over real steps only; the clamp guards an all-padding row.
            pooled = (hidden * weights).sum(1) / weights.sum(1).clamp(min=1.0)

        return self.decoder(pooled)

In [None]:
# Transformer hyper-parameters passed to PETransformerModel below.
nhid = 768      # feed-forward size inside each encoder layer
nlayers = 2     # number of stacked encoder layers
nhead = 2       # attention heads per layer
dropout = 0.2   # dropout rate for the positional encoding and encoder layers

In [None]:
# PEDataset concatenates each embedding with its lag and lead differences, so
# the model input is three times as wide as the stored embedding.
embed_size = train_emb.shape[-1] * 3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = NeuralNet(LSTM_UNITS=LSTM_UNITS, n_classes=n_classes)
model = PETransformerModel(n_classes, embed_size, nhead, nhid, nlayers, dropout=dropout)
model.to(device)

In [None]:
param_optimizer = list(model.named_parameters())
# Parameters whose names contain one of these fragments are exempt from weight
# decay. The 'LayerNorm.*' fragments never match nn.TransformerEncoderLayer,
# whose normalisation layers are named norm1 / norm2, so those names are
# listed explicitly; otherwise the norm weights would be decayed as soon as
# DECAY is set above zero.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'norm1.weight', 'norm2.weight']
plist = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': DECAY},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

optimizer = optim.Adam(plist, lr=lr)
# Multiply the learning rate by lrgamma after every scheduler step.
scheduler = StepLR(optimizer, 1, gamma=lrgamma, last_epoch=-1)
# Expects raw logits; the sigmoid is applied inside the loss.
criterion = torch.nn.BCEWithLogitsLoss()

In [None]:
def validate(model, loader):
    """Run the model over ``loader`` without gradient tracking.

    The model is switched to eval mode and left there; the caller is
    responsible for calling ``model.train()`` again.

    Args:
        model: Module mapping a batch's ``emb`` tensor to logits.
        loader: Iterable of batches that carry ``emb`` and ``labels``.

    Returns:
        Tuple ``(probs, mean_loss)``: the sigmoid outputs of all batches
        concatenated along axis 0, and the running mean of the per-batch
        losses (every batch weighs the same, whatever its size).
    """
    valls = []
    current_loss_mean = 0.0
    model.eval()
    tqdm_loader = tqdm(loader)
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for step, batch in enumerate(tqdm_loader):
            inputs = batch["emb"].to(device, dtype=torch.float)
            y = batch['labels'].to(device, dtype=torch.float)

            logits = model(inputs)
            valls.append(torch.sigmoid(logits).cpu().numpy())

            loss = criterion(logits, y)

            current_loss_mean = (current_loss_mean * step + loss.item()) / (step + 1)
            tqdm_loader.set_description(f"validation loss : {current_loss_mean:.4}")

    return np.concatenate(valls, 0), current_loss_mean

In [None]:
%%time

# Train for n_epochs. Each epoch runs one pass over train_loader, saves the
# weights, steps the LR scheduler and then evaluates on valid_loader.
for epoch in range(n_epochs):
    print("Epoch: {}".format(epoch + 1))
    current_loss_mean = 0.0
    tr_loss = 0.0
    model.train()
    tqdm_loader = tqdm(train_loader)
    for step, batch in enumerate(tqdm_loader):
        y = batch['labels'].to(device, dtype=torch.float)
        # NOTE(review): mask is moved to the device but not used below.
        mask = batch['mask'].to(device, dtype=torch.bool)
        x = batch['emb'].to(device, dtype=torch.float)
        
        logits = model(x).to(device, dtype=torch.float)
        
        # get the mask for masked labels
        # maskidx = mask.view(-1) == 1
        # y = y.view(-1, n_classes)[maskidx]
        # logits = logits.view(-1, n_classes)[maskidx]
        
        # Get loss
        loss = criterion(logits, y)
        tr_loss += loss.item()
        
        # Clear old gradients, back-propagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Log the average loss so far every 100 steps.
        if step % 100 == 0:
            print('Trn step {} of {} trn lossavg {:.5f}'. \
                  format(step, len(train_loader), (tr_loss / (1 + step))))
        
        # Running mean of the batch losses, shown on the progress bar.
        current_loss_mean = (current_loss_mean * step + loss.item()) / (step + 1)
        tqdm_loader.set_description(f"train loss : {current_loss_mean:.4}")
    
    # Checkpoint the weights of this epoch (file name is 1-based).
    output_model_file = "transformer_epoch{}.pth".format(epoch + 1)
    torch.save(model.state_dict(), output_model_file)

    scheduler.step()
    
    # validate() returns predicted probabilities, which are stored in `logits`
    # here, together with the mean validation loss.
    logits, val_loss = validate(model, valid_loader)