In [None]:
# Print the GPU / driver status of the current runtime.
!nvidia-smi

Wed Nov 13 03:34:11 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.50       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P0    29W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [None]:
# Fetch the vocalfolds repository; later cells read vocalfolds/img and vocalfolds/annot.
!git clone https://github.com/imesluh/vocalfolds.git
# Fetch pytorch-semseg and move its `ptsemseg` package into the working
# directory so that `from ptsemseg.models.segnet import segnet` resolves.
!git clone https://github.com/meetshah1995/pytorch-semseg.git
!mv pytorch-semseg/ptsemseg ptsemseg

Cloning into 'vocalfolds'...
remote: Enumerating objects: 1181, done.[K
remote: Total 1181 (delta 0), reused 0 (delta 0), pack-reused 1181[K
Receiving objects: 100% (1181/1181), 204.46 MiB | 50.98 MiB/s, done.
Resolving deltas: 100% (45/45), done.
Cloning into 'pytorch-semseg'...
remote: Enumerating objects: 1088, done.[K
remote: Total 1088 (delta 0), reused 0 (delta 0), pack-reused 1088[K
Receiving objects: 100% (1088/1088), 277.32 KiB | 6.76 MiB/s, done.
Resolving deltas: 100% (738/738), done.


In [None]:
import os
import numpy as np
from numpy.random import permutation

# Train / valid / test split.
#
# For each of the 8 sequence directories (patient1: seq1-4, patient2: seq5-8)
# the file positions 0..seq_size-1 are randomly permuted and cut into
# train / valid / test index arrays. The resulting lists hold one index array
# per sequence, in sequence order.
#
# The permutation is drawn from a dedicated, seeded generator so that the
# split is identical on every run. This matters because Trainer resumes from
# 'last.ckpt' when it exists: with an unseeded split, a restarted kernel would
# draw a different partition and a resumed model could already have been
# trained on images that are now in the test set.
SPLIT_SEED = 0

train_ratio = 0.5
valid_ratio = 0.25
train_idxs = []
valid_idxs = []
test_idxs = []
split_rng = np.random.RandomState(SPLIT_SEED)
for j in range(2):
    for i in range(4):
        seq_dir = 'vocalfolds/img/patient{}/seq{}'.format(j+1, j*4+i+1)
        seq_size = len(os.listdir(seq_dir))
        train_cut = int(seq_size*train_ratio)
        valid_cut = train_cut + int(seq_size*valid_ratio)
        # `perm` already is a shuffled arange(seq_size); slice it directly.
        perm = split_rng.permutation(seq_size)
        train_idxs.append(perm[:train_cut])
        valid_idxs.append(perm[train_cut:valid_cut])
        test_idxs.append(perm[valid_cut:])

In [None]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

class VFDataset(Dataset):
    '''
    Custom Dataset for Vocal Folds data.

    All selected images and annotations are decoded eagerly in the constructor
    and kept in memory.

    Args:
        idxs_list: sequence of 8 index collections, one per sequence directory
            (seq1..seq8). Each index is a position in the sorted file listing
            of that sequence's image directory.
        img_size: side length the annotation is reshaped to; the annotation
            must therefore contain exactly img_size*img_size values.

    Each item is an (img, annot) pair. Both members are 4-tuples
    (tensor, 'patientN', 'seqM', file name).
    '''

    def __init__(self, idxs_list, img_size=512):
        self.imgs = []
        self.annots = []
        to_tensor = transforms.ToTensor()
        for s in range(8):
            # seq1-4 belong to patient1, seq5-8 to patient2.
            pt = 1 if s < 4 else 2
            patient = 'patient{}'.format(pt)
            seq = 'seq{}'.format(s+1)
            img_dir = 'vocalfolds/img/patient{}/seq{}'.format(pt, s+1)
            annot_dir = 'vocalfolds/annot/patient{}/seq{}'.format(pt, s+1)
            # Sorted so that an index always refers to the same file.
            seq_list = sorted(os.listdir(img_dir))
            for idx in idxs_list[s]:
                fname = seq_list[idx]
                # Context managers close the underlying files right after the
                # pixel data has been converted, instead of leaking handles.
                with Image.open(os.path.join(img_dir, fname)) as img_file:
                    img = to_tensor(img_file)
                with Image.open(os.path.join(annot_dir, fname)) as annot_file:
                    # ToTensor scales to [0, 1]; scale back to label values.
                    # round() keeps the later `.long()` truncation from
                    # turning a label k into k-1 through float round-off.
                    annot = to_tensor(annot_file).mul(255).round().view(img_size, img_size)

                self.imgs.append((img, patient, seq, fname))
                self.annots.append((annot, patient, seq, fname))

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        return self.imgs[idx], self.annots[idx]

# Build one in-memory dataset per split (train, valid, test - in that order).
train_dataset, valid_dataset, test_dataset = (
    VFDataset(split_idxs)
    for split_idxs in (train_idxs, valid_idxs, test_idxs)
)

In [None]:
from torch.nn.functional import one_hot

# Computes class weights from the pixel frequency of every class in the
# train and valid datasets: weight_c is proportional to
# total_pixels / (pixels_c * num_classes), and the weights are normalised
# to sum to 1.

num_classes = 7
# Integer accumulator: a float32 counter stops being exact once a class
# exceeds 2**24 pixels.
class_occurs = torch.zeros(num_classes, dtype=torch.long)
for dataset in (train_dataset, valid_dataset):
    for _, annot in dataset:
        labels = annot[0].long().reshape(-1)
        # A label >= num_classes makes bincount longer than the accumulator,
        # so the in-place add fails loudly instead of miscounting.
        class_occurs += torch.bincount(labels, minlength=num_classes)

counts = class_occurs.double()
total_occurs = counts.sum()
# A class that never occurs gets weight 0 rather than inf, which would turn
# every weight into NaN after normalisation.
class_weights = torch.where(
    counts > 0,
    total_occurs / (counts.clamp(min=1) * num_classes),
    torch.zeros_like(counts),
)
class_weights = (class_weights / class_weights.sum()).float()
print(class_weights)

tensor([0.0595, 0.0058, 0.0073, 0.0262, 0.8542, 0.0103, 0.0367])


In [None]:
# One loader per split; only the training loader reshuffles its samples.
train_loader, valid_loader, test_loader = (
    DataLoader(split_dataset, batch_size=6, shuffle=reshuffle)
    for split_dataset, reshuffle in (
        (train_dataset, True),
        (valid_dataset, False),
        (test_dataset, False),
    )
)

In [None]:
# Sanity check: pull one batch and print the image / annotation tensor shapes.
# The builtin next() is used because the iterator's `.next()` method is a
# Python 2 idiom that DataLoader iterators no longer provide.
img, annot = next(iter(train_loader))
x = img[0]  # element 0 of each item tuple is the tensor
y = annot[0]
print(x.size(), y.size())

torch.Size([6, 3, 512, 512]) torch.Size([6, 512, 512])


In [None]:
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

save_epoch = 50 # save model every save_epoch

class Trainer:
    """Trains and evaluates a segmentation model with class-weighted cross entropy.

    Uses Adam (lr=0.001) and a ReduceLROnPlateau scheduler that is stepped
    with the mean validation IoU after every epoch. On construction the
    trainer restores 'last.ckpt' from the working directory if it exists.

    Args:
        model: network mapping an image batch to per-class score maps with
            the class dimension at index 1.
        train_loader: loader yielding (img, annot) pairs whose element 0 is
            the batched tensor.
        valid_loader: loader of the same format, evaluated after each epoch.
        class_weights: 1-D tensor of per-class loss weights.
        device: device the batches are moved to (the model is expected to
            already live there).
        num_classes: number of classes evaluated by `eval`.
        img_size: image side length; stored for use by callers.
    """

    def __init__(self, model, train_loader, valid_loader, class_weights, device, num_classes=7, img_size=512):
        self.model = model
        self.optimizer = Adam(self.model.parameters(), lr=0.001)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=0.1)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.class_weights = class_weights
        self.criterion = CrossEntropyLoss(weight=self.class_weights.to(device))
        self.num_classes = num_classes
        self.img_size = img_size

        self.load_model()

    def train(self, n_epochs):
        """Run `n_epochs` epochs of training, validating after each one.

        A checkpoint is written every `save_epoch` epochs.
        """
        for epoch in range(n_epochs):
            # eval() switches the model to inference mode, so switch back at
            # the start of every epoch.
            self.model.train()
            # Reset per epoch so the displayed value is this epoch's mean
            # loss rather than an average over all epochs so far.
            sum_loss = 0.0
            n_seen = 0
            with tqdm(self.train_loader) as pbar:
                for img, annot in pbar:
                    x = img[0].to(self.device) # tensors are in img[0] and annot[0]
                    y = annot[0].to(self.device).long()
                    y_pred = self.model(x)

                    loss = self.criterion(y_pred, y)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    n_seen += x.size(0)
                    sum_loss += (loss.item() * x.size(0))
                    pbar.set_description('epoch {}: {:.3f}'.format(epoch, sum_loss / n_seen))

            accuracy, precisions, recalls, IoUs = self.eval(self.valid_loader)
            self.scheduler.step(IoUs.mean())
            if (epoch+1) % save_epoch == 0:
                self.save_model(epoch)

    def eval(self, testloader):
        """Evaluate the model on every batch of `testloader`.

        The model is put into inference mode and left there.

        Returns:
            (accuracy, precisions, recalls, IoUs): accuracy is a 0-dim CPU
            float tensor computed over all pixels of all batches; the other
            three are CPU float tensors with one entry per class. A metric
            whose denominator is zero is reported as 0 instead of NaN.
        """
        def new_counter():
            # Integer counters on the compute device: exact, no device mixing.
            return torch.zeros(self.num_classes, dtype=torch.long, device=self.device)

        true_positives = new_counter()
        false_positives = new_counter()
        false_negatives = new_counter()
        correct = 0
        total = 0

        # Without this, layers such as batch norm / dropout would still run
        # in training mode during evaluation.
        self.model.eval()
        with torch.no_grad():
            with tqdm(testloader) as pbar:
                for img, annot in pbar:
                    x = img[0].to(self.device)
                    y = annot[0].to(self.device).long()
                    # Class dimension is index 1, so argmax yields (B, H, W).
                    y_pred = self.model(x).argmax(dim=1)

                    # Accumulate per-class confusion counts.
                    for c in range(self.num_classes):
                        y_c = y.eq(c)
                        y_pred_c = y_pred.eq(c)

                        true_positives[c] += (y_c & y_pred_c).sum()
                        false_positives[c] += (y.ne(c) & y_pred_c).sum()
                        false_negatives[c] += (y_c & y_pred.ne(c)).sum()

                    correct += int((y == y_pred).sum())
                    total += y.numel()

                    accuracy, _, _, IoUs = self._summarize(
                        true_positives, false_positives, false_negatives, correct, total)
                    pbar.set_description('accuracy {:.3f}, avg IoU {:.3f}'.format(
                        float(accuracy), float(IoUs.mean())))

        return self._summarize(true_positives, false_positives, false_negatives, correct, total)

    @staticmethod
    def _summarize(true_positives, false_positives, false_negatives, correct, total):
        """Turn accumulated counts into (accuracy, precisions, recalls, IoUs)."""
        tp = true_positives.cpu().double()
        fp = false_positives.cpu().double()
        fn = false_negatives.cpu().double()
        # clamp(min=1) only changes 0/0 cases, where the numerator is 0 too.
        precisions = (tp / (tp + fp).clamp(min=1)).float()
        recalls = (tp / (tp + fn).clamp(min=1)).float()
        IoUs = (tp / (tp + fp + fn).clamp(min=1)).float()
        accuracy = torch.tensor(float(correct) / max(total, 1), dtype=torch.float32)
        return accuracy, precisions, recalls, IoUs

    def save_model(self, epoch):
        """Write model/optimizer/scheduler state to '<epoch>.ckpt' and 'last.ckpt'."""
        state = {'model': self.model.state_dict(),
                 'optimizer': self.optimizer.state_dict(),
                 'scheduler': self.scheduler.state_dict()}
        torch.save(state, '{}.ckpt'.format(epoch))
        torch.save(state, 'last.ckpt')

    def load_model(self, ckpt_name=None):
        """Restore state from `ckpt_name` (default 'last.ckpt'); no-op if the file is missing."""
        if ckpt_name is None:
            ckpt_name = 'last.ckpt'

        if os.path.exists(ckpt_name):
            # map_location lets a checkpoint saved on one device be restored
            # on another (e.g. GPU checkpoint on a CPU-only runtime).
            state = torch.load(ckpt_name, map_location=self.device)
            self.model.load_state_dict(state['model'])
            self.optimizer.load_state_dict(state['optimizer'])
            self.scheduler.load_state_dict(state['scheduler'])
            print('model loaded from {}'.format(ckpt_name))

In [None]:
from ptsemseg.models.segnet import segnet

# Use the GPU when one is available; fall back to CPU so the notebook still
# runs (slowly) on a runtime without CUDA instead of failing here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = segnet(n_classes=7).to(device)

In [None]:
# Build the trainer (this also restores 'last.ckpt' if present) and train.
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    class_weights=class_weights,
    device=device,
)
trainer.train(n_epochs=200)

epoch 0: 1.839: 100%|██████████| 45/45 [00:26<00:00,  1.87it/s]
accuracy 0.410, avg IoU 0.171: 100%|██████████| 22/22 [00:04<00:00,  4.66it/s]
epoch 1: 1.686: 100%|██████████| 45/45 [00:26<00:00,  1.88it/s]
accuracy 0.438, avg IoU 0.214: 100%|██████████| 22/22 [00:04<00:00,  4.64it/s]
epoch 2: 1.589: 100%|██████████| 45/45 [00:26<00:00,  1.88it/s]
accuracy 0.492, avg IoU 0.248: 100%|██████████| 22/22 [00:04<00:00,  4.66it/s]
epoch 3: 1.524: 100%|██████████| 45/45 [00:26<00:00,  1.88it/s]
accuracy 0.543, avg IoU 0.267: 100%|██████████| 22/22 [00:04<00:00,  4.65it/s]
epoch 4: 1.465: 100%|██████████| 45/45 [00:26<00:00,  1.87it/s]
accuracy 0.563, avg IoU 0.298: 100%|██████████| 22/22 [00:04<00:00,  4.66it/s]
epoch 5: 1.425: 100%|██████████| 45/45 [00:26<00:00,  1.88it/s]
accuracy 0.605, avg IoU 0.304: 100%|██████████| 22/22 [00:04<00:00,  4.67it/s]
epoch 6: 1.392: 100%|██████████| 45/45 [00:26<00:00,  1.87it/s]
accuracy 0.554, avg IoU 0.309: 100%|██████████| 22/22 [00:04<00:00,  4.64it/s]

In [None]:
# Evaluate on the held-out test split and print each metric on its own line.
accuracy, precisions, recalls, IoUs = trainer.eval(test_loader)
for metric in (accuracy, precisions, recalls, IoUs):
    print(metric)

accuracy 0.897, avg IoU 0.708: 100%|██████████| 23/23 [00:05<00:00,  4.66it/s]


tensor(0.8974, device='cuda:0')
tensor([0.5729, 0.9222, 0.8632, 0.8638, 0.8321, 0.8966, 0.7524])
tensor([0.5580, 0.8464, 0.9187, 0.8068, 0.8155, 0.9290, 0.9400])
tensor([0.3941, 0.7900, 0.8020, 0.7158, 0.7002, 0.8392, 0.7179])


In [None]:
# Unweighted (macro) average of the per-class precision and recall.
print(precisions.mean(), recalls.mean())

tensor(0.8147) tensor(0.8306)


In [None]:
# Per-class report: IoU, precision, recall and F1 for every class.
class_names = ['void', 'vocal folds', 'other tissue', 'glottal space', 'pathology', 'surgical tool', 'intubation']
print("{:20}\t{:>10}\t{:>10}\t{:>10}\t{:>10}".format('<Class Name>', 'IoU', 'Precision', 'Recall', 'F1-Score'))
for n, i, p, r in zip(class_names, IoUs, precisions, recalls):
    i, p, r = float(i), float(p), float(r)
    # F1 is the harmonic mean 2*p*r / (p+r); the factor 2 was missing, which
    # halved every reported F1. Report 0 when precision + recall is not positive.
    f = 2*p*r / (p+r) if (p+r) > 0 else 0.0
    print("{:20}\t{:10.3f}\t{:10.3f}\t{:10.3f}\t{:10.3f}".format(n, i, p, r, f))

<Class Name>        	       IoU	 Precision	    Recall	  F1-Score
void                	     0.394	     0.573	     0.558	     0.283
vocal folds         	     0.790	     0.922	     0.846	     0.441
other tissue        	     0.802	     0.863	     0.919	     0.445
glottal space       	     0.716	     0.864	     0.807	     0.417
pathology           	     0.700	     0.832	     0.815	     0.412
surgical tool       	     0.839	     0.897	     0.929	     0.456
intubation          	     0.718	     0.752	     0.940	     0.418


In [None]:
# RGB colour (0-255) for each class index, and a function that maps a label
# map to an RGB segmentation image.

label_colors = np.array([(128, 128, 128),    # gray, void
                         (255, 0, 0),        # red, vocal folds
                         (0, 0, 255),        # blue, other tissue
                         (0, 255, 0),        # green, glottal space
                         (127, 0, 255),      # purple, pathology
                         (255, 128, 0),      # orange, surgical tool
                         (255, 255, 0)],      # yellow, intubation
                        dtype=np.float32
                        )

def decode_segmap(image, num_classes=7):
    """Colour a 2-D label map.

    Args:
        image: (H, W) array or CPU tensor of class indices.
        num_classes: labels 0..num_classes-1 are coloured from `label_colors`;
            any other value stays black.

    Returns:
        (3, H, W) tensor with channel values in [0, 1]. The dtype follows a
        floating-point input and is float32 otherwise.
    """
    labels = np.asarray(image)
    # Integer inputs must not dictate the output dtype: the scaled colours
    # are fractions and would all be truncated to 0.
    out_dtype = labels.dtype if np.issubdtype(labels.dtype, np.floating) else np.float32
    rgb = np.zeros(labels.shape + (3,), dtype=out_dtype)

    for l in range(num_classes):
        rgb[labels == l] = label_colors[l] / 255

    return torch.from_numpy(rgb).permute(2, 0, 1)

In [None]:
from torchvision.utils import save_image

# Draw a few random batches from the test dataset and save, for each batch,
# an image grid of the inputs, the ground-truth maps and the predicted maps.

NUM_SAMPLE_BATCHES = 4

sample_loader = DataLoader(test_dataset, batch_size=6, shuffle=True)

# Inference mode, so normalisation layers do not use per-batch statistics.
trainer.model.eval()
with torch.no_grad():
    for count, (img, annot) in enumerate(sample_loader):
        if count >= NUM_SAMPLE_BATCHES:
            break

        # Work on whatever the batch actually contains: the last batch of the
        # dataset can hold fewer than batch_size images.
        real_images = img[0]
        logits = trainer.model(real_images.to(device))
        # Class dimension is index 1 -> (B, H, W) label maps.
        y_pred = logits.argmax(dim=1).cpu().float()

        ground_truths = torch.stack([decode_segmap(y) for y in annot[0]])
        predictions = torch.stack([decode_segmap(p) for p in y_pred])

        save_image(real_images, 'real_images_{}.png'.format(count+1))
        save_image(ground_truths, 'ground_truths_{}.png'.format(count+1))
        save_image(predictions, 'predictions_{}.png'.format(count+1))