In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torchvision.models import densenet121
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from PIL import Image
import os
import numpy as np

from sklearn.metrics import roc_auc_score

In [2]:
class ChestXrayDataSet(Dataset):
    """Image/label dataset driven by a whitespace-separated list file.

    Every non-blank line of the list file is expected to hold an image file
    name followed by integer labels, for example::

        00000001_000.png 0 1 0 0

    The image name is joined onto ``data_dir``; the labels are kept as a list
    of ints and returned as a ``torch.FloatTensor``.
    """

    def __init__(self, data_dir, image_list_file, transform=None, num_samples=100):
        """
        Args:
            data_dir: directory the image file names are joined onto.
            image_list_file: path of the text file listing images and labels.
            transform: optional callable applied to each loaded PIL image.
            num_samples: maximum number of entries read from the start of the
                list file. ``None`` reads every entry.

        Raises:
            ValueError: if a label token cannot be parsed as an int.
        """
        image_names = []
        labels = []
        with open(image_list_file, "r") as f:
            for line in f:
                # Stop before parsing once the requested number of samples is reached.
                if num_samples is not None and len(image_names) >= num_samples:
                    break
                items = line.split()
                if not items:
                    # Blank / whitespace-only lines (e.g. a trailing newline at
                    # the end of the file) carry no sample; skip them instead of
                    # failing on items[0].
                    continue
                image_names.append(os.path.join(data_dir, items[0]))
                labels.append([int(i) for i in items[1:]])

        self.image_names = image_names
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index: the index of item

        Returns:
            image and its labels
        """
        image_name = self.image_names[index]
        # Image.open keeps the underlying file open; the context manager closes
        # it as soon as the pixel data has been copied out by convert().
        with Image.open(image_name) as raw_image:
            image = raw_image.convert('RGB')
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.FloatTensor(label)

    def __len__(self):
        """Return the number of samples that were read from the list file."""
        return len(self.image_names)


In [3]:
class DenseNet121(nn.Module):
    """torchvision DenseNet-121 (pretrained) with a sigmoid classification head.

    The stock classifier is swapped for ``Linear(in_features, out_size)``
    followed by ``Sigmoid``, so the forward pass yields one value in [0, 1]
    per output unit.
    """

    def __init__(self, out_size):
        """
        Args:
            out_size: number of output units of the replacement classifier.
        """
        super().__init__()
        backbone = torchvision.models.densenet121(pretrained=True)
        # Reuse the width of the original classifier input for the new head.
        head = nn.Sequential(
            nn.Linear(backbone.classifier.in_features, out_size),
            nn.Sigmoid(),
        )
        backbone.classifier = head
        self.densenet121 = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.densenet121(x)

In [4]:
# Paths and hyper-parameters used by the cells below.
DATA_DIR = './ChestX-ray14/images'
# Built with os.path.join instead of a literal containing backslashes: the old
# literal relied on the invalid escape sequences "\l" and "\m" and only resolved
# on Windows. On Windows this still evaluates to the same backslash path.
TRAIN_IMAGE_LIST = os.path.join('ChestX-ray14', 'labels', 'mytraine.txt')
CKPT_PATH = 'myckpt.pth.tar'
N_CLASSES = 14   # number of output units requested from the model
BATCH_SIZE = 8   # samples per DataLoader batch

In [5]:
# Turn on cuDNN benchmark mode (torch.backends.cudnn.benchmark) for this session.
cudnn.benchmark = True

In [6]:
# Initialize and load the model
model = DenseNet121(N_CLASSES).cuda()
model = torch.nn.DataParallel(model).cuda()

if os.path.isfile(CKPT_PATH):
    print("=> loading checkpoint")
    modelCheckpoint = torch.load(CKPT_PATH)['state_dict']
    for k in list(modelCheckpoint.keys()):
        index = k.rindex('.')
        if (k[index - 1] == '1' or k[index - 1] == '2'):
            modelCheckpoint[k[:index - 2] + k[index - 1:]] = modelCheckpoint[k]
            del modelCheckpoint[k]
    model.load_state_dict(modelCheckpoint)
    print("=> loaded checkpoint")
else:
    print("=> no checkpoint found")



=> no checkpoint found


In [7]:
# Per-channel normalisation applied after ToTensor(): (x - mean) / std.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

In [8]:
# `normalize` is defined in the preceding cell; the identical re-definition
# that used to sit here was redundant and has been dropped so the
# normalisation constants live in exactly one place.

# Define transformations for training images
train_transform = transforms.Compose([
    transforms.Resize(256),              # resize to 256 before cropping
    transforms.RandomCrop(224),          # random 224x224 crop
    transforms.RandomHorizontalFlip(),   # random left-right flip
    transforms.ToTensor(),
    normalize
])

In [12]:
# Training dataset: `num_samples` is not passed, so the dataset's own default
# limit on the number of list entries applies.
train_dataset = ChestXrayDataSet(
    data_dir=DATA_DIR,
    image_list_file=TRAIN_IMAGE_LIST,
    transform=train_transform,
)

# Shuffled batches, loaded in the main process (num_workers=0), no pinned memory.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
)

In [13]:
# Binary cross-entropy on the model's sigmoid outputs.
criterion = nn.BCELoss()

# Adam over every parameter of the (DataParallel-wrapped) model.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [14]:
# Put the model in training mode and run the optimisation loop.
model.train()
num_epochs = 100
for epoch in range(num_epochs):
    # Accumulates the loss of the batches seen since the last printout.
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        # Move the batch onto the GPU.
        inputs = inputs.cuda()
        labels = labels.cuda()

        # Standard step: clear gradients, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Report the mean loss after every 10th batch, then restart the sum.
        if (i + 1) % 10 == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0


[1,    10] loss: 0.448
[2,    10] loss: 0.327
[3,    10] loss: 0.322
[4,    10] loss: 0.269
[5,    10] loss: 0.259
[6,    10] loss: 0.240
[7,    10] loss: 0.236
[8,    10] loss: 0.191
[9,    10] loss: 0.197
[10,    10] loss: 0.186
[11,    10] loss: 0.177
[12,    10] loss: 0.173
[13,    10] loss: 0.179
[14,    10] loss: 0.140
[15,    10] loss: 0.129
[16,    10] loss: 0.106
[17,    10] loss: 0.101
[18,    10] loss: 0.081
[19,    10] loss: 0.105
[20,    10] loss: 0.114
[21,    10] loss: 0.102
[22,    10] loss: 0.110
[23,    10] loss: 0.095
[24,    10] loss: 0.068
[25,    10] loss: 0.070
[26,    10] loss: 0.085
[27,    10] loss: 0.063
[28,    10] loss: 0.071
[29,    10] loss: 0.065
[30,    10] loss: 0.075
[31,    10] loss: 0.083
[32,    10] loss: 0.065
[33,    10] loss: 0.053
[34,    10] loss: 0.050
[35,    10] loss: 0.045
[36,    10] loss: 0.048
[37,    10] loss: 0.032
[38,    10] loss: 0.039
[39,    10] loss: 0.034
[40,    10] loss: 0.048
[41,    10] loss: 0.044
[42,    10] loss: 0.042
[

In [15]:
# Persist a checkpoint to CKPT_PATH: the last epoch index, the model and
# optimizer state dicts, and the loss of the last processed batch.
torch.save(
    dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    ),
    CKPT_PATH,
)


In [None]:
def accuracy():
    