In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import loadMedicalMNIST as load_data
import Model as neural_network
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from torchvision import transforms
from pathlib import Path
from torchsummary import summary

In [2]:
root_dir = 'Documents/datasets/MedicalMNIST'
df = load_data.get_labels_df(root_dir)
dataset = load_data.MedicalMNIST(df, root_dir, transform=load_data.data_transform())
print(len(dataset))

                  0  1
0  Hand/001498.jpeg  0
1  Hand/004360.jpeg  0
2  Hand/005988.jpeg  0
3  Hand/001162.jpeg  0
4  Hand/009552.jpeg  0
58954


In [3]:
train_set, test_set = torch.utils.data.random_split(dataset,
                                                   [48954,10000])

In [4]:
#Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
#Set the hyperparameters
n_epochs = 10
lr = 0.001
in_channels = 3
output_classes = 6
batch_size = 64
train_loader = DataLoader(train_set, batch_size=(batch_size), shuffle=True)
test_loader = DataLoader(test_set, batch_size=(batch_size), shuffle=True)

In [6]:
model = neural_network.MNIST_CNN(in_channels, output_classes).to(device)
print(summary(model,input_size = (3,64,64)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 62, 62]             896
         MaxPool2d-2           [-1, 32, 31, 31]               0
            Conv2d-3           [-1, 16, 29, 29]           4,624
         MaxPool2d-4           [-1, 16, 14, 14]               0
           Flatten-5                 [-1, 3136]               0
           Dropout-6                 [-1, 3136]               0
            Linear-7                   [-1, 64]         200,768
            Linear-8                    [-1, 6]             390
Total params: 206,678
Trainable params: 206,678
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 1.35
Params size (MB): 0.79
Estimated Total Size (MB): 2.18
----------------------------------------------------------------
None


In [7]:
#Loss and the optimizer
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = lr)

In [8]:
#Train the network
def train(model, n_epochs, train_loader):
    for epoch in range(n_epochs):
        for batch, (data, targets) in enumerate(train_loader):
            data = data.to(device=device)
            targets = targets.to(device=device)
            
            #Forward
            scores = model(data)
            loss = loss_function(scores, targets)
            
            #Backward
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient descent
            optimizer.step()
            
        print(epoch, "Current Loss:", loss)
train(model, n_epochs, train_loader)

0 Current Loss: tensor(0.0034, grad_fn=<NllLossBackward0>)
1 Current Loss: tensor(0.0005, grad_fn=<NllLossBackward0>)
2 Current Loss: tensor(0.0001, grad_fn=<NllLossBackward0>)
3 Current Loss: tensor(0.0004, grad_fn=<NllLossBackward0>)
4 Current Loss: tensor(3.1569e-06, grad_fn=<NllLossBackward0>)
5 Current Loss: tensor(4.6934e-05, grad_fn=<NllLossBackward0>)
6 Current Loss: tensor(4.4806e-07, grad_fn=<NllLossBackward0>)
7 Current Loss: tensor(2.8177e-06, grad_fn=<NllLossBackward0>)
8 Current Loss: tensor(3.0419e-07, grad_fn=<NllLossBackward0>)
9 Current Loss: tensor(0.0040, grad_fn=<NllLossBackward0>)


In [9]:
def evaluate(loader, model):
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for x,y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            
            scores = model(x)
            _, pred = scores.max(1)
            correct += (pred == y).sum()
            total += pred.size(0)
        print("Accuracy:", correct/total*100, "%")
evaluate(test_loader, model)

Accuracy: tensor(99.2000) %
