In [1]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models.feature_extraction import create_feature_extractor
import torch.nn.functional as F
import cv2

import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.tensorboard import SummaryWriter

import os
from tqdm.notebook import tqdm
import json

Loading model

In [2]:
from MyModel import MyResNet
from MyLoss import W_CEL
from count_ap import count_ap

In [4]:
# model = MyResNet(n_classes=12)
model = torch.load('cxr8_w_cl_map_129eps.pt')

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model.to(device)

Getting heatmap from output from transition layer & prediction layer weights

In [5]:
test_input = torch.rand(16, 3, 32, 32)
test_input = test_input.to(device)

In [6]:
# getting output from transition layer
return_nodes = {
    "transition": "transition"
}
model2 = create_feature_extractor(model, return_nodes=return_nodes)

intermediate_outputs = model2(test_input)

In [7]:
###### multiplication #####
# original shapes: torch.Size([8, 2048]) and torch.Size([16, 2048, 32, 32])
result = intermediate_outputs['transition'].permute(0, 2, 3, 1) @ torch.transpose(model.prediction.weight, 0, 1)

In [8]:
result.shape

torch.Size([16, 32, 32, 12])

---

In [6]:
%cd mimic_pa

/data/iasviridov/work/chest/detr_work/mimic/mimic_pa


In [7]:
df_train = pd.read_csv("mimic_train.csv")
df_val = pd.read_csv("mimic_val.csv")
df_test = pd.read_csv("mimic_test.csv")

In [8]:
%cd ..

/data/iasviridov/work/chest/detr_work/mimic


In [9]:
with open('mimic_id2label.json', 'r') as fp:
    id2label = json.load(fp)
id2label

{'0': 'Atelectasis',
 '1': 'Cardiomegaly',
 '2': 'Consolidation',
 '3': 'Edema',
 '4': 'Enlarged Cardiomediastinum',
 '5': 'Fracture',
 '6': 'Lung Lesion',
 '7': 'Lung Opacity',
 '8': 'Pleural Effusion',
 '9': 'Pleural Other',
 '10': 'Pneumonia',
 '11': 'Pneumothorax'}

In [10]:
class MimicDataset(Dataset):
    
    def __init__(self, main_root='/data/iasviridov/work/chest/detr_work/mimic/mimic_pa/', 
                 mode="train", transform=None):
        
        self.train = True if mode == "train" else False
        self.transform = transform
        self.root = main_root
        self.data = pd.read_csv(self.root + 'mimic_train.csv' if self.train else self.root + 'mimic_val.csv')
        
    def __len__(self):
        # return dataset length
        return len(self.data)
    
    def __getitem__(self, index):
        # load image
        try:
            sample = cv2.imread(self.root + self.data.loc[index, 'dicom_id'] + '.jpg')
            sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)
        except:
            print(self.root + self.data.loc[index, 'dicom_id'] + '.jpg', index)
        
        # transform image
        sample = self.transform(sample)
        # normalization
        sample = sample.float() / 255
        # return dict with image and class label
        label = self.data.loc[index, self.data.columns[3:-1]].values.astype(None)

        return {
            "image": sample,
            "label": label, # int label
        }

In [11]:
class Transform:
    def __init__(
        self, hflip_prob: float = 0.4, ssr_prob: float = 0.4, random_bc_prob: float = 0.4, image_size = 32
    ):
        self.transform = A.Compose(
            [
                A.HorizontalFlip(p=hflip_prob),
                
                A.ShiftScaleRotate(
                    shift_limit=0.0625, scale_limit=0.1, rotate_limit=10, p=ssr_prob),
                
                A.RandomBrightnessContrast(p=random_bc_prob),
                
                A.Resize(height=image_size, width=image_size),
                
                A.pytorch.ToTensorV2()
            ]
        )

    def __call__(self, image):
        image = self.transform(image=image)["image"]
        return image

In [12]:
train_mimic = MimicDataset(mode="train", transform=Transform())
val_mimic = MimicDataset(mode="val", transform=Transform())

In [21]:
class Trainer:
    def __init__(self, model, optimizer, 
                 train_dataset, val_dataset, id2label, 
                 log_name=None, scheduler=None, 
                 batch_size=64, weights=True,
                 sampler=None,
                 num_workers = 16,
                 device=None, smooth_beta = 0.8
                ):
        
        self.optimizer = optimizer
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.num_workers = num_workers
        self.id2label = id2label

        self.batch_size = batch_size
        self.scheduler = scheduler

        self.device = device if device != None else torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        
        self.loss = W_CEL()
        self.sampler = sampler

        if log_name == None:
            log_name = 'some_unknown_model'
            
        self.writer = SummaryWriter(log_dir = "./my_logs/" + log_name)
        
        self.smooth_beta = smooth_beta
        self.avg_train_loss = None
        self.avg_val_loss = None

    def train(self, num_epochs):
        
        model = self.model
        optimizer = self.optimizer
        
        if self.sampler is not None:
            if self.sampler == 'custom':
                train_sampler = VinDrSampler()
            elif self.sampler == 'weighted':
                train_sampler = torch.utils.data.WeightedRandomSampler(self.weights.type('torch.DoubleTensor'), 
                                                                           len(self.weights))
            ## shuffle = False! ###
            train_loader = DataLoader(self.train_dataset, shuffle=False, pin_memory=True, batch_size=self.batch_size
                                     ,sampler=train_sampler, num_workers = self.num_workers)
        else:
            train_loader = DataLoader(self.train_dataset, shuffle=True, pin_memory=True, batch_size=self.batch_size
                                     , num_workers = self.num_workers)
            
        val_loader = DataLoader(self.val_dataset, shuffle=False, pin_memory=True, batch_size=self.batch_size
                               , num_workers = self.num_workers)
        
        best_loss = float('inf')
        
        for epoch in range(num_epochs):
            
            print('---------- Epoch number {} started! -----------'.format(epoch+1))
            model.train()
            
            if self.scheduler != None:
                self.writer.add_scalar('Learning rate', optimizer.param_groups[0]["lr"], global_step=epoch)
                
            for batch in tqdm(train_loader):
                batch = {k: v.to(self.device) for k, v in batch.items()}
                
                x = batch['image']
                y = batch['label']

                logits = model(x)
                
                self.loss.beta_pos = y.shape[0] / y.sum()
                self.loss.beta_neg = y.shape[0] / (y.shape[0]**2 - y.sum())

                loss = self.loss.forward(logits, y)
                
                # calculate smoothed loss
                if self.avg_train_loss is None:
                    self.avg_train_loss = loss.item()
                else:
                    self.avg_train_loss = self.smooth_beta * self.avg_train_loss + (1 - self.smooth_beta) * loss.item()
                
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        
            self.writer.add_scalar('Train loss mean', self.avg_train_loss, global_step=epoch)
            self.avg_train_loss = None
            
            self.writer.add_scalar('Train mAP', count_ap(y, torch.sigmoid(logits), self.id2label), global_step=epoch)
            model.eval()

        
            for batch in tqdm(val_loader):
                batch = {k: v.to(self.device) for k, v in batch.items()}
                
                x = batch['image']
                y = batch['label']

                logits = model(x)

                self.loss.beta_pos = y.sum()
                self.loss.beta_neg = y.shape[0]**2 - y.sum()

                loss = self.loss.forward(logits, y)
                
                # calculate smoothed loss
                if self.avg_val_loss is None:
                    self.avg_val_loss = loss.item()
                else:
                    self.avg_val_loss = self.smooth_beta * self.avg_val_loss + (1 - self.smooth_beta) * loss.item()

   
            self.writer.add_scalar('Val loss mean', self.avg_val_loss, global_step=epoch)
            avg_val_loss = None
        
            self.writer.add_scalar('Val mAP', count_ap(y, torch.sigmoid(logits), self.id2label), global_step=epoch)
            

                
                
            if self.scheduler != None:
                self.scheduler.step(mean_loss)

In [22]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [23]:
trainer = Trainer(model, optimizer, train_mimic, val_mimic, id2label,
                  log_name = 'cxr8_w_cl_map_from2ep', 
                  batch_size=64,
                  device=device
                  )

In [24]:
trainer.train(num_epochs = 250)

---------- Epoch number 1 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 2 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 3 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 4 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 5 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 6 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 7 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 8 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 9 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 10 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 11 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 12 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 13 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 14 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 15 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 16 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 17 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 18 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 19 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 20 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 21 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 22 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 23 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 24 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 25 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 26 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 27 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 28 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 29 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 30 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 31 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 32 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 33 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 34 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 35 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 36 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 37 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 38 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 39 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 40 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 41 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 42 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 43 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 44 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 45 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 46 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 47 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 48 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 49 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 50 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 51 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 52 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 53 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 54 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 55 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 56 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 57 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 58 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 59 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 60 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 61 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 62 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 63 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 64 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 65 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 66 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 67 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 68 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 69 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 70 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 71 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 72 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 73 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 74 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 75 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 76 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 77 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 78 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

---------- Epoch number 79 started! -----------


  0%|          | 0/1071 [00:00<?, ?it/s]

/data/iasviridov/work/chest/detr_work/mimic/mimic_pa/32102641-d886c446-c2ccbbee-118f27a0-3ec9adff.jpg 39941/data/iasviridov/work/chest/detr_work/mimic/mimic_pa/afe39d3d-8f53bab9-c182f84e-43133325-dfdf8a74.jpg
/data/iasviridov/work/chest/detr_work/mimic/mimic_pa/597feb24-936d4232-1b4373ad-9f3a175f-ce047edd.jpg  6400142547
/data/iasviridov/work/chest/detr_work/mimic/mimic_pa/64b9eb3b-f34fa8a9-fc75f27f-53cc3b2c-eca0f252.jpg
 31383
/data/iasviridov/work/chest/detr_work/mimic/mimic_pa/cff72cf7-a50a1717-5122f706-591182d4-3c77d70b.jpg 65346
/data/iasviridov/work/chest/detr_work/mimic/mimic_pa/e404c257-197ae5cb-c2a7dd3c-03999252-6cf26079.jpg/data/iasviridov/work/chest/detr_work/mimic/mimic_pa/35dce161-4d03da08-38649914-0b6629fb-c8ab20fc.jpg 28046
 25278
/data/iasviridov/work/chest/detr_work/mimic/mimic_pa/05956af7-12abcc25-c86fdf23-d206ce82-a83e7163.jpg 65845
/data/iasviridov/work/chest/detr_work/mimic/mimic_pa/79776392-cf72dda9-c63b6d49-2b05bd17-c84f098c.jpg 14214


KeyboardInterrupt: 

In [25]:
torch.save(model, "cxr8_w_cl_map_129eps.pt")

---

Validation

In [78]:
df = pd.DataFrame({'prob': [], 'true': [], 'label': []})
df

Unnamed: 0,prob,true,label


In [79]:
filtered = []

for i in range(0,12):
    df_i = pd.DataFrame.from_dict({'prob': torch.sigmoid(model(test_input))[i,:].cpu().detach().numpy(), 
                     'true': torch.randint(0, 2, (16,12))[i,:].cpu().detach().numpy(), 
                     'label': list(id2label.values())}
                   )
    df = df.append(df_i)
    del df_i

In [80]:
for cls in id2label.values():
    print(average_precision_score(y_true=df[df['label']==cls]['true'].values,
                   y_score=df[df['label']==cls]['prob'].values))

0.49242424242424243
0.3409090909090909
0.63
0.8009700176366843
0.8125
0.6201704545454545
0.3646031746031746
0.83658810325477
0.5399659863945578
0.5806878306878307
0.5780612244897959
0.45568783068783075
