In [None]:
import os
import torch
import torchvision

import torch.utils
import torch.utils.data
import torch.nn.functional as F

from PIL import Image
from matplotlib import pyplot as plt
import wandb

# Seed torch's global RNG so torch-random work later in the notebook (e.g. the
# new classifier layer's init and DataLoader shuffling) is repeatable.
# NOTE(review): GPU kernels may still be nondeterministic; this seeds torch only.
torch.manual_seed(0)

In [None]:
# Sanity check: displays whether a CUDA device is visible to this kernel.
torch.cuda.is_available()

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Class names in label-index order: images in sub-directory source_dir[i] get label i.
class_names = ['COVID', 'NORMAL', 'PNEUMONIA']
# Dataset root; each class's images live in a sub-directory named after the class.
root_dir = 'data'
source_dir = list(class_names)

In [None]:
class ChestXRayDataset(torch.utils.data.Dataset):
    """Image-classification dataset read from one sub-directory per class.

    Files under ``root_dir/source_dir[i]`` get label ``i``. Files are indexed in
    sorted name order so the sample order (and any seeded split built on top of
    it) does not depend on the filesystem's directory-listing order.

    Args:
        root_dir: directory containing the class sub-directories.
        source_dir: ordered list of class sub-directory names.
        transform: optional callable applied to each PIL image in
            ``__getitem__`` (e.g. a torchvision transform pipeline).
    """

    def __init__(self, root_dir, source_dir, transform=None):
        self.root_dir = root_dir
        self.source_dir = source_dir
        self.transform = transform
        self.images = []
        self.labels = []
        for label, class_dir in enumerate(source_dir):
            class_path = os.path.join(root_dir, class_dir)
            # os.listdir order is arbitrary; sort for a reproducible sample order.
            for name in sorted(os.listdir(class_path)):
                path = os.path.join(class_path, name)
                # Hidden files (e.g. .DS_Store) and sub-directories cannot be
                # opened as images; indexing them would only fail later in
                # __getitem__, so skip them here.
                if name.startswith('.') or not os.path.isfile(path):
                    continue
                self.images.append(path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The file is decoded as 3-channel RGB and passed through ``transform``
        when one was provided.
        """
        # The context manager closes the file handle; convert() returns a new
        # in-memory image that outlives it.
        with Image.open(self.images[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

In [None]:
# Per-channel normalization statistics (the standard ImageNet values used by
# torchvision's pretrained models).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

dataset = ChestXRayDataset(root_dir, source_dir, transform=transform)

In [None]:
# 80/20 train/test split. A dedicated, seeded generator makes this cell
# idempotent: re-running it reproduces the same split instead of drawing a new
# one from the global RNG. If re-run after training, a new split could move
# already-trained-on images into the test set.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(0)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator
)

In [None]:
batch_size = 6

# Both loaders read `batch_size` (previously a duplicated literal), so changing
# it here is enough. Only the training data is shuffled.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# ResNet-18 with torchvision's default pretrained weights. The classification
# head is replaced by a fresh linear layer with one output per class. Its input
# width is read from the existing head rather than hard-coded.
model = torchvision.models.resnet18(weights=torchvision.models.resnet.ResNet18_Weights.DEFAULT)
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
# Assign (Module.to returns the module) instead of leaving a bare expression
# that would print the entire model repr as cell output.
model = model.to(device)

In [None]:
wandb.init(
    # set the wandb project where this run will be logged
    project="XRay-COVID",

    # track hyperparameters and run metadata
    # NOTE(review): "learning_rate" and "epochs" duplicate literals set in later
    # cells (the Adam lr and num_epochs) and must be kept in sync by hand.
    # "conv_kernel" is not read anywhere in this notebook; confirm it is meaningful.
    config={
        "learning_rate": 0.001,
        "batch_size": batch_size,
        "conv_kernel": 3,
        "epochs": 10,
    }
)

In [None]:
# Adam over every model parameter. The loss is cross-entropy on the model's raw
# logits, averaged over the batch (the loss's default reduction, spelled out).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criteria = torch.nn.CrossEntropyLoss(reduction='mean')

In [None]:
# Per-sample prediction table: image, predicted label, true label, then one
# softmax-probability column per class. The class columns are taken from
# class_names so they follow the same order as the model's outputs.
columns = ['image', 'guess', 'truth', *class_names]
test_table = wandb.Table(columns=columns)

In [None]:
def log_test_predictions(images, labels, outputs, predicted, test_table,
                         mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Append one row per sample in the batch to a wandb Table.

    Each row is ``(image, guess, truth, *class_probabilities)``, matching a
    table whose columns are ``['image', 'guess', 'truth', <one per class>]``.

    Args:
        images: batch of image tensors, indexed as (batch, channel, H, W) and
            normalized with ``mean``/``std`` by the input transform.
        labels: ground-truth class indices for the batch.
        outputs: raw model logits for the batch (softmaxed over dim 1 here).
        predicted: predicted class indices for the batch.
        test_table: the ``wandb.Table`` to append rows to.
        mean, std: per-channel statistics used by the input ``Normalize``
            transform (defaults match that transform). They are undone before
            logging so images render with their original intensities. Pass
            ``None`` for either to log the tensors unchanged.
    """
    scores = F.softmax(outputs.detach(), dim=1).cpu().numpy()

    pixels = images.detach()
    if mean is not None and std is not None:
        mean_t = torch.as_tensor(mean, dtype=pixels.dtype, device=pixels.device).view(1, -1, 1, 1)
        std_t = torch.as_tensor(std, dtype=pixels.dtype, device=pixels.device).view(1, -1, 1, 1)
        # Invert Normalize (x * std + mean) and clip to the [0, 1] range ToTensor produced.
        pixels = (pixels * std_t + mean_t).clamp(0.0, 1.0)
    pixels = pixels.cpu().numpy()
    truths = labels.cpu().numpy()
    guesses = predicted.cpu().numpy()

    for image, truth, guess, score in zip(pixels, truths, guesses, scores):
        # wandb.Image takes channels-last arrays, hence (C, H, W) -> (H, W, C).
        # Values go in column order: 'guess' is the prediction, 'truth' the label.
        test_table.add_data(wandb.Image(image.transpose(1, 2, 0)), guess, truth, *score)

In [None]:

def train(epoch):
    """Run one training epoch over ``train_loader`` and log the loss of every step.

    Uses the notebook globals ``model``, ``train_loader``, ``optimizer``,
    ``criteria``, ``device`` and ``num_epochs`` (the last only for the
    progress message).

    Args:
        epoch: zero-based epoch index, used in the progress message.
    """
    model.train()
    # len(loader) counts the final partial batch; len(dataset) // batch_size did not.
    steps_per_epoch = len(train_loader)
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criteria(outputs, labels)
        loss.backward()
        optimizer.step()

        # Log a plain Python float rather than the graph-attached tensor.
        loss_value = loss.item()
        wandb.log({"loss": loss_value})
        if (i + 1) % 10 == 0:
            print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f' % (epoch + 1, num_epochs, i + 1, steps_per_epoch, loss_value))

def test():
    """Evaluate ``model`` on ``test_loader``, log accuracy and per-sample predictions.

    Uses the notebook globals ``model``, ``test_loader``, ``device``,
    ``test_table`` and ``epoch``. ``epoch`` is the training loop's variable,
    so it must be defined before this is called.

    Raises:
        ValueError: if ``test_loader`` yields no samples (accuracy undefined).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            log_test_predictions(images, labels, outputs, predicted, test_table)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")
    accuracy = 100 * correct / total
    wandb.log({"epoch": epoch, "accuracy": accuracy})
    print('Test Accuracy of the model on the test images: %d %%' % accuracy)


num_epochs = 10
total_step = len(train_loader)
for epoch in range(num_epochs):
    # test() appends every test sample to the global `test_table`. Reusing a
    # single table would re-log all earlier epochs' rows each epoch, with no
    # epoch column to tell them apart, and grow without bound. Start a fresh
    # table per epoch instead.
    test_table = wandb.Table(columns=columns)
    train(epoch)
    test()
    wandb.log({"test_predictions" : test_table})


In [None]:
# Mark the wandb run as finished so the logged metrics and prediction tables
# finish uploading before the notebook ends.
wandb.finish()