# Segmentation Using Autoencoder Post-Processing

In [None]:
import pandas as pd
import os
from tqdm import tqdm
import random
import itertools
from torchvision import transforms
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch import nn, optim
import time
import copy
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import timm
import math

In [None]:
# List the raw component images and the label folder contents.
data_file = os.listdir('./Myositis Images/CD27 Panel-- Component')
label_file = os.listdir('./Myositis Images/Labels/CD27 cell labels')

# These two entries of the label folder are dropped from the listing;
# `remove` raises ValueError if either entry is missing.
for excluded_entry in ('Bounding Rectangle', 'Mask Labels'):
    label_file.remove(excluded_entry)

mask_label_file = os.listdir('./Myositis Images/Labels/CD27 cell labels/Mask Labels')

In [None]:
# Keep only the sample names present in all three listings. Each listing is
# reduced to a sample name by cutting a fixed number of trailing characters
# (9, 11 and 17 — presumably the per-folder suffix plus extension; TODO confirm).
selected_data = (
    {file_name[:-9] for file_name in data_file}
    & {file_name[:-11] for file_name in label_file}
    & {file_name[:-17] for file_name in mask_label_file}
)
print(len(selected_data))

In [None]:
# Three-channel normalisation statistics. They are kept as module constants but
# are intentionally NOT applied in the pipelines below:
#   * the validation and test pipelines never normalised, so normalising only
#     the training images made training and evaluation inputs inconsistent;
#   * CustomDataset runs the same pipeline on the mask, so a Normalize step
#     would also rescale the mask values;
#   * the statistics describe three channels while the model is built with
#     in_channels=1.
# NOTE(review): if input normalisation is wanted, add it with single-channel
# statistics to all three splits and apply it to the image only.
DATASET_IMAGE_MEAN = (0.485, 0.456, 0.406)
DATASET_IMAGE_STD = (0.229, 0.224, 0.225)

# Training: light geometric augmentation (rotation of at most 3 degrees plus
# random vertical / horizontal flips), then conversion back to a tensor.
transform_train = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomRotation(3),
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Validation / test: no augmentation, only the PIL round trip.
transform_val = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

transform_test = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

In [None]:
# To keep the results comparable we are using
# the same splits as Van Buren et al (https://www.sciencedirect.com/science/article/pii/S0022175922000205),
# and code at: https://github.com/tniewold/Artificial-Intelligence-and-Deep-Learning-to-Map-Immune-Cell-Types-in-Inflamed-Human-Tissue


def _tile_pairs(samples, patient_ids, tiles_per_sample=12):
    """Return the (image file, mask file) name pairs for the selected patients.

    Args:
        samples: iterable of sample names; the first 13 characters of a name
            are compared against ``patient_ids``.
        patient_ids: collection of 13-character prefixes to keep.
        tiles_per_sample: number of tiles (indices ``0 .. n-1``) per sample.

    Returns:
        A flat list of ``(name_data_i.npy, name_mask_i.npy)`` tuples. Samples
        are visited in sorted order so the result does not depend on set
        iteration order and is identical from run to run.
    """
    return [
        (name + '_data_' + str(idx) + '.npy', name + '_mask_' + str(idx) + '.npy')
        for name in sorted(samples)
        if name[:13] in patient_ids
        for idx in range(tiles_per_sample)
    ]


train_list = _tile_pairs(selected_data, ['121919_Myo089', '121919_Myo253', '121919_Myo368'])
validation_list = _tile_pairs(selected_data, ['121919_Myo208', '121919_Myo388'])
test_list = _tile_pairs(selected_data, ['121919_Myo231', '121919_Myo511'])
print('train data: {}'.format(len(train_list)))
print('validation data: {}'.format(len(validation_list)))
print('test data: {}'.format(len(test_list)))

In [None]:
# Display the first (image file name, mask file name) pair as a quick sanity check.
train_list[0]

In [None]:
class CustomDataset(Dataset):
    """Dataset of (image tile, mask tile) pairs stored as ``.npy`` files.

    Args:
        file_list: sequence of ``(image_file_name, mask_file_name)`` tuples.
            Images are read from ``./New data/New Image Tile/image/`` and masks
            from ``./New data/New Image Tile/labels/``.
        transform: optional callable applied to the image tensor and to the
            mask tensor. Random transforms are replayed with the same random
            state for both, so image and mask receive identical augmentation.
    """

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(inputs, mask_label)``.

        Both arrays get a leading axis of size 1 before the optional
        transform. The mask is divided by 255 (presumably stored as 0/255 —
        TODO confirm against the label files).
        """
        image_name, mask_name = self.file_list[idx][0], self.file_list[idx][1]
        # NOTE(review): allow_pickle=True executes pickled payloads; only load trusted files.
        inputs = np.load('./New data/New Image Tile/image/' + image_name, allow_pickle=True)
        mask_label = np.load('./New data/New Image Tile/labels/' + mask_name, allow_pickle=True)
        mask_label = mask_label / 255

        inputs = torch.from_numpy(inputs).unsqueeze(0)
        mask_label = torch.from_numpy(mask_label).unsqueeze(0)

        if self.transform:
            # Snapshot both RNGs so the second call draws exactly the same
            # random parameters (rotation angle, flips) as the first one.
            # Without this the image and its mask are augmented differently
            # and no longer line up. torch's RNG covers current torchvision;
            # Python's `random` is restored too for transforms that use it.
            torch_rng_state = torch.get_rng_state()
            python_rng_state = random.getstate()
            inputs = self.transform(inputs)
            torch.set_rng_state(torch_rng_state)
            random.setstate(python_rng_state)
            mask_label = self.transform(mask_label)

        return (inputs, mask_label)

In [None]:
# One DataLoader per split, all with batches of 16 and 8 worker processes.
# Only the training loader shuffles; only the test loader drops a trailing
# incomplete batch.
dataloader = {
    'train': DataLoader(
        CustomDataset(train_list, transform=transform_train),
        batch_size=16,
        shuffle=True,
        num_workers=8,
        drop_last=False,
    ),
    'validation': DataLoader(
        CustomDataset(validation_list, transform=transform_val),
        batch_size=16,
        shuffle=False,
        num_workers=8,
        drop_last=False,
    ),
    'test': DataLoader(
        CustomDataset(test_list, transform=transform_test),
        batch_size=16,
        shuffle=False,
        num_workers=8,
        drop_last=True,
    ),
}

In [None]:
# Number of (image, mask) pairs in each split, keyed like `dataloader`.
datasize = dict(
    train=len(train_list),
    validation=len(validation_list),
    test=len(test_list),
)
datasize

In [None]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder acting on the last dimension of its input.

    The encoder widens the 480 input features to 15360 and then narrows them
    step by step to a 16-feature bottleneck; the decoder widens them back to
    480. Every linear layer, including the final one, is followed by a ReLU,
    so the output is non-negative.
    """

    def __init__(self):
        super().__init__()
        # Encoder: 480 -> 15360 -> 256 -> 128 -> 64 -> 32 -> 16
        self.enc1 = nn.Linear(480, 15360)
        self.encmid = nn.Linear(15360, 256)
        self.enc2 = nn.Linear(256, 128)
        self.enc3 = nn.Linear(128, 64)
        self.enc4 = nn.Linear(64, 32)
        self.enc5 = nn.Linear(32, 16)

        # Decoder: 16 -> 32 -> 64 -> 128 -> 256 -> 480
        self.dec1 = nn.Linear(16, 32)
        self.dec2 = nn.Linear(32, 64)
        self.dec3 = nn.Linear(64, 128)
        self.dec4 = nn.Linear(128, 256)
        self.dec5 = nn.Linear(256, 480)

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder, ReLU after each layer."""
        layer_chain = (
            self.enc1, self.encmid, self.enc2, self.enc3, self.enc4, self.enc5,
            self.dec1, self.dec2, self.dec3, self.dec4, self.dec5,
        )
        for layer in layer_chain:
            x = F.relu(layer(x))
        return x

In [None]:
# Per-epoch metric histories (module-level so they stay available after training).
# Training: overall / mask / boundary / background accuracy.
train_acc, train_acc_mask, train_acc_boundary, train_acc_bg = [], [], [], []
# Validation: overall / mask / boundary / background accuracy.
val_acc, val_acc_mask, val_acc_boundary, val_acc_bg = [], [], [], []
# IoU histories for validation and training.
iou_q_val, iou_q_t = [], []


def _segmentation_stats(pred, mask):
    """Return one batch's smp (tp, fp, fn, tn) stats for foreground, background and all pixels."""
    foreground = smp.metrics.get_stats(pred == 1, mask == 1, mode='multilabel', threshold=0.5)
    background = smp.metrics.get_stats(pred == 0, mask == 0, mode='multilabel', threshold=0.5)
    overall = smp.metrics.get_stats(pred, mask.long(), mode='multilabel', threshold=0.5)
    return foreground, background, overall


def _epoch_metrics(batch_stats):
    """Pool per-batch stats over an epoch; return (mask_acc, bg_acc, overall_acc, iou)."""
    # batch_stats is a list of (foreground, background, overall) triples, each
    # a (tp, fp, fn, tn) tuple; concatenate every statistic along the sample axis.
    foreground, background, overall = (
        tuple(torch.cat(parts, dim=0) for parts in zip(*group))
        for group in zip(*batch_stats)
    )
    mask_accuracy = smp.metrics.accuracy(*foreground, reduction="macro")
    bg_accuracy = smp.metrics.accuracy(*background, reduction="macro")
    accuracy = smp.metrics.accuracy(*overall, reduction="macro")
    iou_score = smp.metrics.iou_score(*overall, reduction="micro")
    return mask_accuracy, bg_accuracy, accuracy, iou_score


def train_model(model, criterion, optimizer, scheduler,encoder_name, autoencoder,criterion_ae,optimizer_ae,num_epochs=50, maximum_patient=5):
    """Jointly train ``model`` and ``autoencoder`` with validation, checkpointing and early stopping.

    Reads the module-level ``dataloader`` and ``device`` and appends per-epoch
    metrics to the module-level history lists (``train_acc``, ``train_acc_mask``,
    ``train_acc_bg``, ``iou_q_t`` and their validation counterparts).

    Args:
        model: segmentation network; its output is also fed to ``autoencoder``.
        criterion: loss between the segmentation output and the integer mask.
        optimizer: optimizer for ``model``.
        scheduler: learning-rate scheduler, stepped once per training batch.
        encoder_name: used only to build the checkpoint file name.
        autoencoder: network applied to the segmentation output.
        criterion_ae: loss between the autoencoder output and the integer mask.
        optimizer_ae: optimizer for ``autoencoder``.
        num_epochs: maximum number of epochs.
        maximum_patient: training stops once the validation loss has risen for
            more than this many consecutive validated epochs.

    Returns:
        None. The model with the best validation IoU so far is saved to
        ``./Saved_models/Unet-ae-<encoder_name>-scse-depth3-imgnet-es.pt``.
    """
    prev = math.inf   # validation loss of the previously validated epoch
    # NOTE(review): `best` is never raised, so validation runs on every epoch
    # whose training IoU is positive. Kept to preserve the validation schedule.
    best = 0
    counter = 0       # consecutive validated epochs with a rising loss
    last_val = 0      # best validation IoU seen so far
    for epoch in tqdm(range(num_epochs)):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Validation leaves the model in eval mode, so switch back every epoch.
        model.train()
        autoencoder.train()
        # Reset per epoch so the running loss is neither cumulative across
        # epochs nor mixed with the previous validation loss.
        total_loss = 0
        batch_stats = []
        for inputs, mask in dataloader['train']:
            inputs = inputs.to(device)
            # squeeze(1) drops only the channel axis; a bare squeeze() would
            # also drop the batch axis when a batch holds a single sample.
            # Assumes masks arrive as (batch, 1, H, W) — TODO confirm.
            mask = mask.squeeze(1).to(device)
            outputs = model(inputs)
            # Use the autoencoder passed in, not a global that happens to exist.
            outputs_ae = autoencoder(outputs)

            pred = torch.argmax(outputs, dim=1)
            loss1 = criterion(outputs, mask.long())
            # NOTE(review): criterion_ae compares the autoencoder output with
            # the integer mask; confirm shapes/dtypes suit the chosen criterion.
            loss2 = criterion_ae(outputs_ae, mask.long())
            loss = loss1 + loss2
            optimizer.zero_grad()
            optimizer_ae.zero_grad()
            loss.backward()
            optimizer.step()
            optimizer_ae.step()
            # NOTE(review): stepped per batch, so the scheduler's period is
            # counted in batches rather than epochs — confirm this is intended.
            scheduler.step()
            total_loss += loss.item()
            with torch.no_grad():
                batch_stats.append(_segmentation_stats(pred, mask))
        avg_loss = total_loss / len(dataloader['train'])

        # Metrics are pooled over the whole epoch, not taken from the last batch.
        mask_accuracy, bg_accuracy, accuracy, iou_score = _epoch_metrics(batch_stats)
        print("Mask Accuracy:", mask_accuracy)
        print("Background Accuracy:", bg_accuracy)
        print("overall Acc:", accuracy)

        print("Appending to Train List")
        train_acc.append(accuracy)
        train_acc_mask.append(mask_accuracy)
        train_acc_bg.append(bg_accuracy)
        print('IoU:', iou_score)
        epoch_iou = iou_score
        iou_q_t.append(iou_score)
        if epoch_iou > best:
            # Validating
            with torch.no_grad():
                model.eval()
                total_loss = 0
                batch_stats = []
                print('Validation')
                print('-' * 20)
                for inputs, mask in dataloader['validation']:
                    inputs = inputs.to(device)
                    mask = mask.squeeze(1).to(device)
                    outputs = model(inputs)
                    pred = torch.argmax(outputs, dim=1)
                    loss = criterion(outputs, mask.long())
                    total_loss += loss.item()
                    batch_stats.append(_segmentation_stats(pred, mask))
                avg_loss = total_loss / len(dataloader['validation'])

                mask_accuracy, bg_accuracy, accuracy, iou_score = _epoch_metrics(batch_stats)
                print("Mask Accuracy:", mask_accuracy)
                print("Background Accuracy:", bg_accuracy)
                print("overall Acc:", accuracy)
                print("Validation IoU:", iou_score)
                # Record validation history alongside the training history.
                val_acc.append(accuracy)
                val_acc_mask.append(mask_accuracy)
                val_acc_bg.append(bg_accuracy)
                iou_q_val.append(iou_score)
                if iou_score > last_val:
                    last_val = iou_score
                    print("Saving Model!")
                    torch.save(model,
                               './Saved_models/Unet-ae-{}-scse-depth3-imgnet-es.pt'.format(encoder_name))

            # Patience counter: consecutive validated epochs with a rising loss.
            if avg_loss > prev:
                counter += 1
            else:
                counter = 0

            prev = avg_loss

            if counter > maximum_patient:
                print("ES!!!!!!")
                # Actually stop instead of only announcing early stopping.
                break

In [None]:
# Use the first CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Training and Validation Loop

In [None]:
# Train a U-Net with a timm-efficientnet-b0 encoder together with the autoencoder.
encoder_name = 'timm-efficientnet-b0'

# Segmentation network: single-channel input, two output classes, encoder
# depth 3 with scSE attention in the decoder.
model = smp.Unet(
    encoder_name=encoder_name,
    encoder_depth=3,
    encoder_weights='imagenet',
    decoder_channels=(256, 128, 64),
    decoder_attention_type='scse',
    in_channels=1,
    classes=2,
)

# Autoencoder branch with its own loss and optimizer.
ae = Autoencoder().to(device)
criterion_ae = nn.MSELoss()
optimizer_ae = optim.Adam(ae.parameters(), lr=1e-3)

# Segmentation loss: summed, class-weighted cross-entropy.
model = model.to(device)
criterion = nn.CrossEntropyLoss(
    weight=torch.tensor([0.1479139275021023, 0.8520860724978977]).to(device),
    reduction='sum',
)
optimizer = optim.Adam(model.parameters(), lr=3.6e-04, weight_decay=1e-05)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=3.4e-04, last_epoch=-1)

train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    encoder_name,
    autoencoder=ae,
    criterion_ae=criterion_ae,
    optimizer_ae=optimizer_ae,
)

# Testing on Test Set

In [None]:
# Reload the checkpoint written during training and score it on the test split.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the whole module, so only load trusted
# files; recent PyTorch releases may additionally need weights_only=False here.
model = torch.load('./Saved_models/Unet-ae-timm-efficientnet-b0-scse-depth3-imgnet-es.pt',
                   map_location=device)
model = model.to(device)
model.eval()
tp_total=0
fp_total=0
fn_total=0
tn_total=0
# Inference only: no_grad avoids building the autograd graph for every batch.
with torch.no_grad():
    for inputs, mask in dataloader['test']:
        inputs = inputs.to(device)
        # squeeze(1) removes only the channel axis, never the batch axis.
        # Assumes masks arrive as (batch, 1, H, W) — TODO confirm.
        mask = mask.squeeze(1).to(device)
        outputs = model(inputs)
        pred = torch.argmax(outputs, dim=1)
        tp, fp, fn, tn = smp.metrics.get_stats(pred, mask.long(), mode='multilabel', threshold=0.5)
        # Element-wise sums need equally shaped batches; the test loader
        # is built with drop_last=True.
        tp_total+=tp
        fp_total+=fp
        fn_total+=fn
        tn_total+=tn
print('Overall Scores:')
iou_score = smp.metrics.iou_score(tp_total, fp_total, fn_total, tn_total, reduction="micro")
accuracy = smp.metrics.accuracy(tp_total, fp_total, fn_total, tn_total, reduction="macro")
print("Pixel Acc:",accuracy)
print("IoU:",iou_score)