In [1]:
from medmnist import ChestMNIST
from PIL import Image
from torchvision import transforms
import torchvision.models as models
import numpy as np
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import MinMaxScaler, LabelEncoder

In [2]:
# Load the ChestMNIST dataset
train_dataset = ChestMNIST(split="train", download=True, size=224)

n = 1000

train_images = train_dataset.imgs[0:n]
train_labels = train_dataset.labels[0:n]

del train_dataset

Using downloaded and verified file: /Users/thollenbeak/.medmnist/chestmnist_224.npz


In [4]:
preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transformed_images = []

for image in tqdm.tqdm(train_images):
    image = np.float32(image) / 255.0
    image = Image.fromarray(image)
    transformed_images.append(preprocess(image))

x_train_tensor = torch.stack(transformed_images)

100%|██████████| 1000/1000 [00:00<00:00, 2250.96it/s]


In [5]:
# Debugging
num_images = len(transformed_images)
image_size = transformed_images[0].numel()  # Number of elements in one image
dtype_size = transformed_images[0].element_size()  # Size of each element in bytes
total_memory = num_images * image_size * dtype_size
print(f"Total memory required: {total_memory / (1024 ** 3):.2f} GB")

Total memory required: 0.56 GB


In [6]:
y_train_tensor = torch.tensor(train_labels)
#y_validation_tensor = torch.tensor(labels)
#y_test_tensor = torch.tensor(labels)

train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
#validation_dataset = TensorDataset(x_validation_tensor, y_validation_tensor)
#test_dataset = TensorDataset(x_test_tensor, y_test_tensor)

# Create DataLoaders for efficient training and testing data handling
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
#test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)  

In [7]:
model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.DEFAULT, progress=True)

for param in model.features[:11].parameters():
    param.requires_grad = False

model.classifier[1] = nn.Conv2d(512, train_labels.shape[1], kernel_size=(1, 1), stride=(1, 1))
model.classifier[2] = nn.Identity()


In [8]:
optimizer = optim.Adam(model.parameters(), lr = 0.001)
scheduler = StepLR(optimizer, step_size = 2, gamma = 0.5)
criterion = nn.BCEWithLogitsLoss()

In [11]:
for epoch in range(5):
    model.train()
    
    for inputs, targets in tqdm.tqdm(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        targets = targets.float()

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Implement validation step later

100%|██████████| 16/16 [00:12<00:00,  1.24it/s]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]
100%|██████████| 16/16 [00:12<00:00,  1.24it/s]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]
