In [None]:
##############################
######## AJ IGLESIAS #########
##############################

### Malaria Detection Neural Network ###
### PyTorch Implementation ###

##########Import libraries for network ##############
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torchvision import datasets, transforms, models
from torch.utils.data.sampler import SubsetRandomSampler

In [None]:
# Preprocessing pipelines for the training, test and validation images.

# Channel-wise mean / std used by every pipeline. The pretrained ResNet50
# loaded later expects inputs normalised with these (ImageNet) statistics.
imagenetMean = [0.485, 0.456, 0.406]
imagenetStd = [0.229, 0.224, 0.225]

# Training pipeline: random augmentation, then tensor conversion + normalisation.
trainTransforms = transforms.Compose([
    transforms.RandomRotation(30),      # rotate by a random angle in [-30, 30] degrees
    transforms.RandomResizedCrop(224),  # random crop, resized to 224 x 224
    transforms.RandomVerticalFlip(),    # randomly flip the image vertically
    transforms.ToTensor(),              # PIL image -> float tensor
    transforms.Normalize(imagenetMean, imagenetStd),
])

# Test pipeline: deterministic resize + centre crop to 224 x 224.
testTransforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(imagenetMean, imagenetStd),
])

# Validation pipeline: same deterministic steps as the test pipeline.
validationTransforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(imagenetMean, imagenetStd),
])



In [None]:
# Root folder of the cell images, relative to the notebook; ImageFolder
# expects one sub-directory per class underneath it.
images = 'cell_images/'

# Dataset over every image in the folder, read through the training pipeline.
trainData = datasets.ImageFolder(root=images, transform=trainTransforms)

In [None]:
# Split the dataset indices into validation / test / training subsets.
validSize = 0.3  # fraction of ALL images held out for validation
test_size = 0.1  # fraction of ALL images held out for testing

# NOTE(review): `transform` is not referenced by any other visible cell; it is
# kept only so that code relying on the name keeps working.
transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# One index per image in the dataset.
num_Train = len(trainData)
indices = list(range(num_Train))

# Shuffle with a dedicated, seeded generator so the split is identical after
# every "Restart & Run All" and the global NumPy RNG state is left untouched.
splitSeed = 42
np.random.RandomState(splitSeed).shuffle(indices)

# Cut points inside the shuffled index list:
#   [0, validSplit)          -> validation
#   [validSplit, testSplit)  -> test
#   [testSplit, num_Train)   -> training
validSplit = int(np.floor((validSize) * num_Train))
testSplit = int(np.floor((validSize + test_size) * num_Train))

validInd, testInd, trainInd = indices[:validSplit], indices[validSplit:testSplit], indices[testSplit:]

# The three subsets must partition the dataset exactly.
assert len(validInd) + len(testInd) + len(trainInd) == num_Train

# Show the subset sizes (validation, test, training).
print(len(validInd), len(testInd), len(trainInd))

# Samplers that draw only from their own subset of indices.
trainSampler = SubsetRandomSampler(trainInd)
testSampler = SubsetRandomSampler(testInd)
validSampler = SubsetRandomSampler(validInd)

In [None]:
# Build the DataLoaders that batch each split.

# Validation and test images must not go through the random training
# augmentation, so they get their own dataset views with the deterministic
# pipelines. All three datasets read the same folder, and ImageFolder lists
# files in sorted order, so the index splits computed earlier address the
# same images in every view.
validData = datasets.ImageFolder(images, transform = validationTransforms)
testData = datasets.ImageFolder(images, transform = testTransforms)

# Each loader draws only the indices assigned to its split by the sampler.
trainLoad = torch.utils.data.DataLoader(trainData, batch_size = 64, sampler = trainSampler)
testLoad = torch.utils.data.DataLoader(testData, batch_size = 10, sampler = testSampler)
validLoad = torch.utils.data.DataLoader(validData, batch_size = 64, sampler = validSampler)

In [None]:
#Number of batches the test data loader yields (len() of a DataLoader counts batches, not images)
len(testLoad)

In [None]:
# Transfer learning: start from torchvision's pretrained ResNet50.
model = models.resnet50(pretrained=True)

# Freeze every pretrained weight so backpropagation leaves the backbone alone.
for param in model.parameters():
    param.requires_grad_(False)

# Swap the classifier for a fresh 2048 -> 2 linear layer (one output per class).
model.fc = nn.Linear(in_features=2048, out_features=2, bias=True)

# Explicitly mark the new head's parameters as trainable.
fcParameters = model.fc.parameters()
for param in fcParameters:
    param.requires_grad_(True)

# Last expression of the cell: display the resulting architecture.
model


In [None]:
#Training the network
def trainNet(epochs, model, optimizer, criterion, trainLoader=None, validLoader=None):
    """Train ``model`` for ``epochs`` epochs and validate after each epoch.

    Args:
        epochs: number of passes over the training batches.
        model: network to optimise; switched to train mode for the training
            pass and to eval mode for the validation pass of every epoch.
        optimizer: optimizer whose ``zero_grad`` / ``step`` run once per batch.
        criterion: callable ``criterion(output, target)`` returning the loss.
        trainLoader: iterable of ``(data, target)`` training batches.
            Defaults to the notebook-level ``trainLoad``.
        validLoader: iterable of ``(data, target)`` validation batches.
            Defaults to the notebook-level ``validLoad``.

    Returns:
        The same ``model`` object (in eval mode after the last validation pass).
    """
    if trainLoader is None:
        trainLoader = trainLoad
    if validLoader is None:
        validLoader = validLoad

    for epoch in range(1, epochs+1):
        # Running mean losses, reset at the start of every epoch.
        trainLoss = 0.0
        validLoss = 0.0

        #Model Training#
        model.train()
        for batchIndex, (data, target) in enumerate(trainLoader):

            # Clear gradients accumulated by the previous batch.
            optimizer.zero_grad()

            output = model(data)

            #calc Loss
            loss = criterion(output, target)

            #back propagation
            loss.backward()

            # Apply the parameter update.
            optimizer.step()

            # Incremental mean: mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k.
            # loss.item() yields a plain float so no tensor is kept alive.
            trainLoss = trainLoss + ((1 / (batchIndex + 1)) * (loss.item() - trainLoss))

            if batchIndex % 100 == 0:
                print('Epoch %d, Batch %d, Loss: %.6f' % (epoch, batchIndex + 1, trainLoss))

        #evaluate and validate model
        model.eval()
        # No parameter update happens here, so skip autograd bookkeeping.
        with torch.no_grad():
            for batchIndex, (data, target) in enumerate(validLoader):
                output = model(data)
                loss = criterion(output, target)
                #Average validation loss
                validLoss = validLoss + ((1 / (batchIndex + 1)) * (loss.item() - validLoss))

        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(epoch, trainLoss, validLoss))

    return model

In [None]:
# Loss function: cross entropy over the two output classes.
criterion = nn.CrossEntropyLoss()

# Plain SGD over the classifier head's parameters only, learning rate 0.001.
optimizer = optim.SGD(params=model.fc.parameters(), lr=1e-3)

# Train for 4 epochs only, because of limited compute.
trainNet(epochs=4, model=model, optimizer=optimizer, criterion=criterion)

In [None]:
###Test the trained model for accuracy###
def testMod(model, criterion):
    #Capture loss and accuracy
    testLoss = 0.0
    correct = 0.0
    total = 0.0
    
    for batchIndex, (data, target) in enumerate(testLoad):
        #forward function pass
        output = model(data)
        
        #calc loss
        loss = criterion(output, target)
        
        #Update average test loss
        testLoss = testLoss + ((1 / (batchIndex + 1)) * (loss.data - testLoss))
        
        #establish predicted class
        pred = output.data.max(1, keepdim=True)[1]
        
        #compare predictions to true label
        correct += np.sum(np.squeeze(pred.eq(target.data.view_as(pred))).numpy())
        total += data.size(0)
        
    print('Test Loss: {:.6f}\n'.format(testLoss))
    
    #Print algorithm accuracy
    print('\n Test Accuracy: %2d%% (%2d/%2d)' % (100. * correct / total, correct, total))
    

In [None]:
# Print the loss and accuracy of the trained model on the test loader's batches.
testMod(model, criterion)

In [None]:
from PIL import Image
from glob import glob

def inputImage(image):
    """Load an image file and turn it into a normalised single-image batch.

    Args:
        image: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        Tensor of shape (1, 3, 224, 224): the image resized to 224 x 224,
        converted to a tensor, normalised per channel, with a leading batch
        dimension added.
    """
    predictTransform = transforms.Compose([transforms.Resize(size=(224,224)),
                                          transforms.ToTensor(),
                                          transforms.Normalize([0.485, 0.456, .406],
                                                              [0.229, 0.224, 0.225])])
    # Force three channels BEFORE Normalize: its 3-element mean/std cannot be
    # applied to grayscale or RGBA tensors. The context manager closes the
    # file handle; convert() returns an independent in-memory copy.
    with Image.open(image) as rawImage:
        rgbImage = rawImage.convert('RGB')
    return predictTransform(rgbImage).unsqueeze(0)



In [None]:
def predictMalaria(model, class_name, image):
    """Return the predicted class label for one cell image.

    Args:
        model: classifier; it is put into eval mode.
        class_name: sequence of labels, indexed by the model's output position.
        image: image path, passed on to ``inputImage``.

    Returns:
        The entry of ``class_name`` at the index of the largest model output.
    """
    batch = inputImage(image)

    model.eval()
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        index = torch.argmax(model(batch)).item()
    # Use the labels the caller passed in, not the notebook-level global.
    return class_name[index]

In [None]:
# Demo: classify the first five images of each class folder, print the
# predicted label and show the image underneath it.
className = ['Parasitized', 'Uninfected']
infected = np.array(glob('cell_images/Parasitized/*'))
uninfected = np.array(glob('cell_images/Uninfected/*'))

# (paths, label of that folder, the other label) for each class folder.
demoGroups = (
    (infected, 'Parasitized', 'Uninfected'),
    (uninfected, 'Uninfected', 'Parasitized'),
)

for groupPaths, folderLabel, otherLabel in demoGroups:
    for i in range(5):
        imagePath = groupPaths[i]
        image = Image.open(imagePath)
        prediction = predictMalaria(model, className, imagePath)
        # Print the folder's label when the model agrees with it, else the other one.
        print(folderLabel if prediction == folderLabel else otherLabel)
        plt.imshow(image)
        plt.show()