# Semi-supervised Convolutional Auxiliary Deep Generative Model

In [None]:
# Imports
import torch
from IPython.display import clear_output
from glob import glob
import pandas as pd
cuda = torch.cuda.is_available()
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from imgaug import augmenters as iaa
%matplotlib inline
import sys
import os
import pydicom, numpy as np
from skimage.transform import resize
sys.path.append("../semi-supervised-pytorch-master/semi-supervised") # path to models
det_class_path = '../Kaggle/all/stage_2_detailed_class_info.csv' # class info
bbox_path = '../Kaggle/all/stage_2_train_labels.csv' # labels
dicom_dir = '../Kaggle/all/stage_2_train_images/' # train images

Here we show the image-level labels for the scans. The most interesting group is No Lung Opacity / Not Normal, since these are cases that look like opacity but are not. The classes are roughly balanced, so we do not need to deal with a class-imbalance problem; in addition, the labelled training subset built below takes an equal share of images from each class.

In [None]:
# Load the per-image class annotations and report how many rows / patients they cover.
det_class_df = pd.read_csv(det_class_path)
print(len(det_class_df), 'class infos loaded')
print(det_class_df['patientId'].nunique(), 'patient cases')
# Bar chart of the number of annotation rows per class.
det_class_df.groupby('class').size().plot.bar()

In [None]:
# Some useful functions
def indices_to_one_hot(data, nb_classes):
    """Convert integer class indices into one-hot rows.

    ``data`` (a scalar or any, possibly nested, iterable of indices) is
    flattened first; the result has one float row of length ``nb_classes``
    per flattened index.
    """
    flat_indices = np.asarray(data).ravel()
    identity = np.eye(nb_classes)
    return identity[flat_indices]

def batch(iterable, n=1):
    """Yield consecutive slices of ``iterable`` holding at most ``n`` items each.

    ``iterable`` must support ``len()`` and slicing (list, tuple, ndarray);
    the last slice is shorter when the length is not a multiple of ``n``.
    """
    total = len(iterable)
    for start in range(0, total, n):
        stop = min(start + n, total)
        yield iterable[start:stop]

In [None]:
# create training dataset with labelled and unlabelled images
image_df = pd.DataFrame({'path': glob(os.path.join(dicom_dir, '*.dcm'))})
image_df['patientId'] = image_df['path'].map(lambda x: os.path.splitext(os.path.basename(x))[0])

# training/validation split
validation = 0.1
# image resize (the bigger the better, except for computational power ...)
image_resize = 225
# number of labelled images in the training dataset (split evenly across classes)
labelled_images = 1000
# number of unlabelled images (the rest of the training dataset)
unlabelled_images = int(image_df.shape[0]*(1-validation)-labelled_images)
# number of validation images
validation_images = int(validation*image_df.shape[0])
labelled = []
label = []
unlabelled = []

# We don't need the same subject repeated when multiple bounding boxes occur.
# drop_duplicates returns a new frame, so the result must be assigned;
# reset_index keeps the positional row lookups below contiguous.
det_class_df = det_class_df.drop_duplicates().reset_index(drop=True)
# One-hot class matrix (columns in sorted class-name order); dtype=int keeps the
# 0/1 encoding numeric on every pandas version (pandas >= 2 defaults to bool).
allLabel = pd.get_dummies(pd.Series(list(det_class_df['class'])), dtype=int).values
patient_ids = det_class_df['patientId'].values
# patientId -> DICOM path, built once instead of scanning image_df per image.
path_by_patient = dict(zip(image_df['patientId'], image_df['path']))


def _load_image(patient_id):
    """Read a patient's DICOM scan, divide pixels by 255 and resize to image_resize x image_resize (float32)."""
    pixels = pydicom.dcmread(path_by_patient[patient_id]).pixel_array / 255
    resized = resize(pixels, (image_resize, image_resize), anti_aliasing=True, mode='constant')
    return resized.astype(np.float32)  # float32 halves memory; training converts to float32 anyway


# Prepare training dataset: take rows in order until every class has its share
# of the labelled budget. The remainder of an uneven split goes to the first
# classes (e.g. 334/333/333 for 1000 images), so the loop always terminates.
n_classes = allLabel.shape[1]
quota = [labelled_images // n_classes + (1 if c < labelled_images % n_classes else 0)
         for c in range(n_classes)]
label_counts = [0] * n_classes
labelIndex = []
i = 0
while label_counts != quota:
    if i >= len(allLabel):
        raise ValueError('Not enough images to fill the labelled quota of every class')
    c = int(np.argmax(allLabel[i]))
    if label_counts[c] < quota[c]:
        label_counts[c] += 1
        labelIndex.append(i)
        labelled.append(_load_image(patient_ids[i]))
        label.append(allLabel[i])
        if len(labelled) % 1000 == 0:
            print(str(len(labelled)) + ' labelled images out of ' + str(labelled_images) + ' done')
    i += 1

print(str(labelled_images) + ' training images labelled loaded')

# Every other row among the first labelled_images + unlabelled_images rows is
# used without its label. A set keeps the membership test O(1).
used = set(labelIndex)
done = 0
for i in range(labelled_images + unlabelled_images):
    if i in used:
        continue
    used.add(i)
    labelIndex.append(i)
    unlabelled.append(_load_image(patient_ids[i]))
    done += 1
    if done % 1000 == 0:
        print(str(done) + ' unlabelled images out of ' + str(unlabelled_images) + ' done')

print(str(unlabelled_images) + ' training images unlabelled loaded')

# Prepare validation dataset: every remaining row not used for training
labelled_val = []
label_val = []
done = 0
for i in range(min(image_df.shape[0], len(allLabel))):
    if i in used:
        continue
    done += 1
    label_val.append(allLabel[i])
    labelled_val.append(_load_image(patient_ids[i]))
    if done % 1000 == 0:
        print(str(done) + ' images out of ' + str(validation_images) + ' done')

print('Validation images loaded')

trainNbr = np.sum(label, axis=0)
valNbr = np.sum(label_val, axis=0)

print('Summary:')

print('Training images: ' + str(labelled_images + unlabelled_images))
print('Labelled: ' + str(labelled_images) + ', Unlabelled: ' + str(unlabelled_images))
print('Labels: Opacity ' + str(trainNbr[0]) + ', Not-normal ' + str(trainNbr[1]) + ', Normal ' + str(trainNbr[2]))

print('Validation images: ' + str(validation_images))
print('Labels: Opacity ' + str(valNbr[0]) + ', Not-normal ' + str(valNbr[1]) + ', Normal ' + str(valNbr[2]))

## Auxiliary Deep Generative Model

The Auxiliary Deep Generative Model [[Maaløe, 2016]](https://arxiv.org/abs/1602.05473) posits a model with an auxiliary latent variable $a$ that helps infer the latent variables $z$ and $y$. This helps semi-supervised learning by delegating causality to their respective variables.

We now create the model architecture.

In [None]:
from models import AuxiliaryDeepGenerativeModel

# Label, latent z and auxiliary a sizes; h_dim presumably lists the hidden-layer widths.
y_dim, z_dim, a_dim = 3, 128, 128
h_dim = [2048, 1024, 512, 256]

# Images are fed flattened, so the input size is image_resize * image_resize.
model = AuxiliaryDeepGenerativeModel(
    [image_resize * image_resize, y_dim, z_dim, a_dim, h_dim]
)
print(model)

In [None]:
from itertools import cycle
from inference import SVI, DeterministicWarmup, log_gaussian


def log_gauss(x, mu, var):
    """Return the negated ``inference.log_gaussian`` of its arguments.

    Passed to SVI as its ``likelihood`` callable. NOTE(review): ``var`` is
    forwarded unchanged; confirm whether log_gaussian expects a variance or
    a log-variance in that position.
    """
    return -(log_gaussian(x, mu, var))


# Move the model to the GPU *before* constructing the optimizer (recommended
# by the torch.optim documentation) so the optimizer is built over the
# parameters in their final location.
if cuda:
    model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-4)

# We will need to use warm-up in order to achieve good performance.
# Over n=100 calls to SVI the autoencoder changes from deterministic to stochastic.
beta = DeterministicWarmup(n=100)
beta_constant = 0.1
# Weight of the classification term, scaled by total/labelled training images.
alpha = beta_constant * (len(unlabelled) + len(labelled)) / len(labelled)

elbo = SVI(model, likelihood=log_gauss, beta=beta)

The library conveniently provides `SVI`, which does all of the work of calculating the lower bound for both labelled and unlabelled data, depending on whether the label is given. It also handles the enumeration over all labels.

Remember that the labels have to be in a *one-hot encoded* format in order to work with SVI.

In [None]:
from torch.autograd import Variable
from sklearn.decomposition import PCA
import random

n_epochs = 100
batchSize = 15

# Per-epoch history for the live plots (plain Python floats, also saved later).
accuracyTrain = []
accuracyVal = []
LTrain = []
LVal = []
UTrain = []
UVal = []
classTrain = []
classVal = []
JAlphaTrain = []
JAlphaVal = []
# Apply at least one of these augmentations, in random order, to each training batch.
image_augmenter = iaa.SomeOf((1, None), [
    iaa.Fliplr(0.5),
    iaa.Affine(scale=(0.8, 1.2),
               translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)},
               rotate=(-15, 15)),
], random_order=True)


def _images_to_tensor(images):
    """Stack a batch of 2-D images into a (batch, image_resize*image_resize) float32 tensor."""
    stacked = np.asarray(images, dtype=np.float32).reshape(-1, image_resize * image_resize)
    return torch.from_numpy(stacked)


def _labels_to_tensor(labels):
    """Stack a batch of one-hot label rows into a float32 tensor (works for int or bool rows)."""
    return torch.from_numpy(np.asarray(labels, dtype=np.float32))


def _batch_accuracy(logits, y):
    """Fraction of rows whose arg-max prediction matches the one-hot target, as a Python float."""
    return torch.mean((torch.max(logits, 1)[1] == torch.max(y, 1)[1]).float()).item()


for epoch in range(n_epochs):
    model.train()
    total_L_train, total_U_train, total_classification_loss_train, total_loss_train, accuracy_train = (0, 0, 0, 0, 0)
    total_L_val, total_U_val, total_classification_loss_val, total_loss_val, accuracy_val = (0, 0, 0, 0, 0)
    m_train, m_val = (0, 0)

    # Shuffle the data every epoch (labelled and label must keep the same ordering!)
    pairs = list(zip(labelled, label))
    random.shuffle(pairs)
    random.shuffle(unlabelled)
    labelled, label = zip(*pairs)
    latent = []
    y_pred = []
    # The labelled batches are cycled so that every unlabelled batch is paired with one.
    for x, y, u in zip(cycle(batch(labelled, batchSize)), cycle(batch(label, batchSize)), batch(unlabelled, batchSize)):
        m_train += 1
        x = image_augmenter.augment_images(list(x))
        u = image_augmenter.augment_images(list(u))

        x, y, u = _images_to_tensor(x), _labels_to_tensor(y), _images_to_tensor(u)
        if cuda:
            # They need to be on the same device and be synchronized.
            x, y, u = x.cuda(device=0), y.cuda(device=0), u.cuda(device=0)

        L, _ = elbo(x, y)
        U, _ = elbo(u)

        # Add auxiliary classification loss q(y|x): regular cross entropy.
        logits = model.classify(x)
        classification_loss = -torch.sum(y * torch.log(logits + 1e-8), dim=1).mean()
        J_alpha_train = -L + alpha * classification_loss - U

        J_alpha_train.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_L_train += L.item()
        total_U_train += U.item()
        total_classification_loss_train += classification_loss.item()
        total_loss_train += J_alpha_train.item()
        accuracy_train += _batch_accuracy(logits, y)

    model.eval()
    # No gradients are needed for validation; this avoids building autograd graphs.
    with torch.no_grad():
        for x, y in zip(batch(labelled_val, batchSize), batch(label_val, batchSize)):
            m_val += 1
            x, y = _images_to_tensor(x), _labels_to_tensor(y)
            if cuda:
                x, y = x.cuda(device=0), y.cuda(device=0)

            L, z = elbo(x, y)
            U, _ = elbo(x)
            latent.append(z.detach().cpu().numpy())

            logits = model.classify(x)
            y_pred.append(torch.max(logits, 1)[1].cpu().numpy())
            classification_loss = -torch.sum(y * torch.log(logits + 1e-8), dim=1).mean()
            J_alpha_val = -L + alpha * classification_loss - U

            total_L_val += L.item()
            total_U_val += U.item()
            total_classification_loss_val += classification_loss.item()
            total_loss_val += J_alpha_val.item()
            accuracy_val += _batch_accuracy(logits, y)

    print("Epoch: {}".format(epoch+1))
    print("[Train]\t\t L: {:.2f}, U: {:.2f}, class: {:.2f}, J_a: {:.2f}, accuracy: {:.2f}".format(total_L_train / m_train, total_U_train / m_train, total_classification_loss_train / m_train, total_loss_train / m_train, accuracy_train / m_train))
    print("[Validation]\t L: {:.2f}, U: {:.2f}, class: {:.2f}, J_a: {:.2f}, accuracy: {:.2f}".format(total_L_val / m_val, total_U_val / m_val, total_classification_loss_val / m_val, total_loss_val / m_val, accuracy_val / m_val))

    accuracyTrain.append(accuracy_train / m_train)
    accuracyVal.append(accuracy_val / m_val)
    LTrain.append(total_L_train / m_train)
    LVal.append(total_L_val / m_val)
    UTrain.append(total_U_train / m_train)
    UVal.append(total_U_val / m_val)
    classTrain.append(total_classification_loss_train / m_train)
    classVal.append(total_classification_loss_val / m_val)
    JAlphaTrain.append(total_loss_train / m_train)
    JAlphaVal.append(total_loss_val / m_val)

    plt.figure(1, figsize=(20, 20))
    # (subplot position, train history, validation history, legend suffix, y-axis label)
    panels = (
        (321, accuracyTrain, accuracyVal, 'acc', 'Accuracy'),
        (322, classTrain, classVal, 'class', 'Classification'),
        (323, LTrain, LVal, 'L', 'L'),
        (324, UTrain, UVal, 'U', 'U'),
        (325, JAlphaTrain, JAlphaVal, 'J-alpha', 'J-alpha'),
    )
    for position, train_curve, val_curve, suffix, ylabel in panels:
        plt.subplot(position)
        plt.plot(train_curve, 'r', label='train ' + suffix)
        plt.plot(val_curve, 'b', label='val ' + suffix)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
    #Uncomment for latent space visualization
    #plt.subplot(326)
    #plt.title("Latent space: R:Opacity, G:Not-Normal, B:Normal")
    #plt.xlabel('Dimension 1')
    #plt.ylabel('Dimension 2')
    #latent = np.vstack(latent)
    #latent = np.array(latent, dtype=np.float32).reshape(-1, z_dim)
    #latent = PCA(n_components=2).fit_transform(latent)
    #classes = np.argmax(label_val, axis=1)
    #k = 0
    #for z in latent:
    #    if (classes[k] == 0):
    #        plt.scatter(z[0], z[1], c='red', marker='o')
    #    elif (classes[k] == 1):
    #        plt.scatter(z[0], z[1], c='green', marker='o')
    #    elif (classes[k] == 2):
    #        plt.scatter(z[0], z[1], c='blue', marker='o')
    #    k = k + 1
    plt.show()
    clear_output(wait=True)

# Ground-truth class index of every validation image; the confusion-matrix and
# save cells use it, so it must not live only in the commented-out code above.
classes = np.argmax(label_val, axis=1)

 ### Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix

# Ground-truth class index of each validation image
# (column order of the labels: 0 Opacity, 1 Not-normal, 2 Normal).
classes = np.argmax(label_val, axis=1)
# y_pred holds one array of predicted indices per validation batch of the last
# epoch. Flatten it into a new name so re-running this cell does not break.
y_pred_flat = np.concatenate(y_pred, axis=0)
conf_matrix = confusion_matrix(classes, y_pred_flat[:len(classes)])
print(conf_matrix)

### Conditional sampling
We now create some samples from the model given the class value.

In [None]:
from torch.distributions.normal import Normal

model.eval()

# Class to condition on: 0 Opacity, 1 Not-normal, 2 Normal.
classValue = 0
n_samples = 2

with torch.no_grad():
    z = torch.randn(n_samples, z_dim)
    y = torch.Tensor(indices_to_one_hot([classValue] * n_samples, y_dim))
    if cuda:
        z, y = z.cuda(), y.cuda()
    x_mu, x_log_var = model.sample(z, y)
    # Normal's second argument is the standard deviation (must be > 0), so the
    # log-variance is converted: std = exp(0.5 * log_var).
    # NOTE(review): assumes model.sample returns a log-variance, as its name says.
    norm = Normal(x_mu, torch.exp(0.5 * x_log_var))
    x = norm.sample()

In [None]:
f, axarr = plt.subplots(1, 2, figsize=(40, 40))

# Bring the sampled images back to host memory, one image_resize x image_resize array each.
samples = x.data.view(-1, image_resize, image_resize).cpu().numpy()

# Draw each sample in its own panel, hiding the axes.
for i in range(axarr.size):
    ax = axarr.flat[i]
    ax.imshow(samples[i], cmap='bone')
    ax.axis("off")

### Save model and plot

In [None]:
PATHMODEL = '/home/stce/Scaricati/Unumed/Models/beta01'
PATHFIGURE = '/home/stce/Scaricati/Unumed/Figure/training.npz'

# Create the output folders if they do not exist yet.
for output_path in (PATHMODEL, PATHFIGURE):
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

torch.save(model.state_dict(), PATHMODEL)

# Ground-truth validation classes, recomputed so this cell is self-contained.
classes = np.argmax(label_val, axis=1)
# The per-batch latent arrays differ in length (last batch is smaller), so a
# plain list cannot be turned into one array; stack them along the batch axis.
latent_array = np.vstack(latent) if len(latent) else np.empty((0, z_dim))

np.savez(PATHFIGURE,
         # float() also unwraps 0-dim (possibly CUDA) tensors into plain numbers.
         accuracyTrain=[float(a) for a in accuracyTrain],
         accuracyVal=[float(a) for a in accuracyVal],
         classTrain=classTrain,
         classVal=classVal,
         LTrain=LTrain,
         LVal=LVal,
         UTrain=UTrain,
         UVal=UVal,
         JAlphaTrain=JAlphaTrain,
         JAlphaVal=JAlphaVal,
         conf_matrix=conf_matrix,
         latent=latent_array,
         classes=classes)

