In [1]:
import numpy as np
import torchvision.models as models
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, hamming_loss
from medmnist import ChestMNIST
from PIL import Image


In [2]:
def load_data(data_split, image_nxn_size, n_observations):
    data = ChestMNIST(split=data_split, download=True, size=image_nxn_size)

    if n_observations > 0:
        images = data.imgs[0:n_observations]
        labels = data.labels[0:n_observations]
    else:      
        images = data.imgs
        labels = data.labels

    del data

    return images, labels

In [3]:
train_images, train_labels = load_data(data_split="train", image_nxn_size=224, n_observations=1000)
validation_images, validation_labels = load_data(data_split="val", image_nxn_size=224, n_observations=1000)
test_images, test_labels = load_data(data_split="test", image_nxn_size=224, n_observations=1000)

Using downloaded and verified file: /Users/thollenbeak/.medmnist/chestmnist_224.npz
Using downloaded and verified file: /Users/thollenbeak/.medmnist/chestmnist_224.npz
Using downloaded and verified file: /Users/thollenbeak/.medmnist/chestmnist_224.npz


In [4]:
def preprocess_data(image_set):
    preprocess = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    transformed_images = []

    for image in tqdm.tqdm(image_set):
        image = np.float32(image) / 255.0
        image = Image.fromarray(image)
        transformed_images.append(preprocess(image))

    return torch.stack(transformed_images)

In [5]:
x_train_tensor = preprocess_data(train_images)
x_validation_tensor = preprocess_data(validation_images)
x_test_tensor = preprocess_data(test_images)

100%|██████████| 1000/1000 [00:00<00:00, 2044.73it/s]
100%|██████████| 1000/1000 [00:00<00:00, 2097.33it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1157.97it/s]


In [6]:
# Debugging
# num_images = len(transformed_images)
# image_size = transformed_images[0].numel()  # Number of elements in one image
# dtype_size = transformed_images[0].element_size()  # Size of each element in bytes
# total_memory = num_images * image_size * dtype_size
# print(f"Total memory required: {total_memory / (1024 ** 3):.2f} GB")

In [7]:
y_train_tensor = torch.tensor(train_labels)
y_validation_tensor = torch.tensor(validation_labels)
y_test_tensor = torch.tensor(test_labels)

train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
validation_dataset = TensorDataset(x_validation_tensor, y_validation_tensor)
test_dataset = TensorDataset(x_test_tensor, y_test_tensor)

# Create DataLoaders for efficient training and testing data handling
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=64, shuffle=False)  
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)  

In [15]:
model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.DEFAULT, progress=True)
n_features = 9

for param in model.features[:n_features].parameters():
    param.requires_grad = False

for param in model.features[n_features:].parameters():
    param.requires_grad = True

model.classifier[1] = nn.Conv2d(512, train_labels.shape[1], kernel_size=(1, 1), stride=(1, 1))
model.classifier[2] = nn.Identity()

In [16]:
optimizer = optim.Adam(model.parameters(), lr = 0.1)
#scheduler = StepLR(optimizer, step_size = 2, gamma = 0.5)
criterion = nn.BCEWithLogitsLoss()

In [10]:
# with torch.no_grad():
#     for data, label in train_loader:
#         #print(data.shape, label.shape)
#         validation_outputs = model(data)
#         validation_targets = label
#         print(validation_outputs)
#         #print(criterion(validation_outputs, validation_targets))
#         break

In [17]:
best_model_path = "best_transfer_learning_model.pth"
best_loss = np.inf
best_accuracy = 0
best_epoch = 0
sigmoid_threshold = 0.5

for epoch in range(5):
    model.train()
    
    for inputs, targets in tqdm.tqdm(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        targets = targets.float()

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

    model.eval()
    all_targets = []
    all_predictions = []
    validation_loss = 0.0
    validation_accuracy = 0.0
    with torch.no_grad():
        for validation_inputs, validation_targets in validation_loader:
            validation_outputs = model(validation_inputs)
            validation_targets = validation_targets.float()
            validation_loss += criterion(validation_outputs, validation_targets)

            probabilities = torch.sigmoid(validation_outputs)
            predictions = (probabilities > sigmoid_threshold).float()

            all_targets.extend(validation_targets)
            all_predictions.extend(predictions)

    validation_accuracy = accuracy_score(all_targets, all_predictions)
    validation_loss /= len(validation_loader.dataset)

    if validation_loss < best_loss:
        best_loss = validation_loss
    
    # Save the model with the best accuracy
    if validation_accuracy > best_accuracy:
        best_accuracy = validation_accuracy
        torch.save(model.state_dict(), best_model_path)

    print(f"Epoch: {epoch + 1}, Validation Loss: {validation_loss}, Accuracy: {validation_accuracy}")

100%|██████████| 16/16 [00:14<00:00,  1.11it/s]


Epoch: 1, Validation Loss: 0.0033600383903831244, Accuracy: 0.564


100%|██████████| 16/16 [00:13<00:00,  1.15it/s]


Epoch: 2, Validation Loss: 0.002948910463601351, Accuracy: 0.564


100%|██████████| 16/16 [00:13<00:00,  1.16it/s]


Epoch: 3, Validation Loss: 0.0028402081225067377, Accuracy: 0.564


100%|██████████| 16/16 [00:14<00:00,  1.09it/s]


Epoch: 4, Validation Loss: 0.0028804230969399214, Accuracy: 0.564


100%|██████████| 16/16 [00:13<00:00,  1.15it/s]


Epoch: 5, Validation Loss: 0.002876340877264738, Accuracy: 0.564


In [12]:
# Implement the evaluation phase
# Try using the CNN as a feature extractor for traditional ML methods
# Include a larger classification layer for finetuning
# Address potential class imbalance