In [16]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda, Resize
from medmnist import INFO, TissueMNIST, PneumoniaMNIST, DermaMNIST, OrganAMNIST
from transformers import ViTForImageClassification
from transformers import AdamW

def load_data(data_flag='pneumoniamnist', batch_size=16):
    n_classes = len(INFO[data_flag]['label'])
    transform = Compose([
        Resize((224, 224)),
        ToTensor(),
        Lambda(lambda x: x.repeat(3, 1, 1)),
        Normalize(mean=[0.5]*3, std=[0.5]*3)
    ])

    train_dataset = PneumoniaMNIST(split='train', transform=transform, download=True)
    val_dataset = PneumoniaMNIST(split='val', transform=transform, download=True)
    test_dataset = PneumoniaMNIST(split='test', transform=transform, download=True)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader, n_classes

def setup_model(n_classes):
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=n_classes)
    return model

def train_and_evaluate(model, train_loader, val_loader, device):
    print(f'Using device: {device}')
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    best_accuracy = 0
    best_model_path = 'organA224.pth'

    model.train()
    for epoch in range(10):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            labels = labels.squeeze(1).long()

            outputs = model(images).logits
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        total, correct = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                labels = labels.squeeze(1).long()

                outputs = model(images).logits
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        print(f'Epoch {epoch+1}, Loss: {loss.item()}, Validation Accuracy: {accuracy}%')

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)
            print(f"Best model saved with accuracy: {accuracy}% at epoch {epoch+1}")

    print(f'Training complete. Best model was saved with an accuracy of {best_accuracy}%.')




def main():
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    #device = "cuda" if torch.cuda.is_available() else "cpu"
    train_loader, val_loader, test_loader, n_classes = load_data()
    model = setup_model(n_classes)
    train_and_evaluate(model, train_loader, val_loader, device)

if __name__ == '__main__':
    #main()
    pass



In [17]:
def EvaluateModel(loader, model, device, label):
    model.to(device)

    model.eval()
    total, correct = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        labels = labels.squeeze(1).long()
        outputs = model(images).logits
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        #print('Predicted: ', predicted)
        #print('Actual: ', labels)

    accuracy = 100 * correct / total
    print(f'{label} Accuracy: {accuracy:.2f}%')

train_loader, val_loader, test_loader, n_classes = load_data()
device = "mps" if torch.backends.mps.is_available() else "cpu"
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=n_classes)
model.load_state_dict(torch.load('pneumonia224.pth'))
EvaluateModel(train_loader, model, device, 'Training')
EvaluateModel(val_loader, model, device, 'Validation')
EvaluateModel(test_loader, model, device, 'Test')




Using downloaded and verified file: /Users/oscarrosman/.medmnist/pneumoniamnist.npz
Using downloaded and verified file: /Users/oscarrosman/.medmnist/pneumoniamnist.npz
Using downloaded and verified file: /Users/oscarrosman/.medmnist/pneumoniamnist.npz


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Training Accuracy: 99.64%
Validation Accuracy: 97.71%
Test Accuracy: 87.66%


In [11]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(model, image):
    # Put model in evaluation mode
    model.eval()
    
    # Forward pass to get attention weights
    with torch.no_grad():
        outputs = model(image)
        attention_weights = outputs.attentions  # Extract attention weights
    
    # Plot attention maps
    num_layers = len(attention_weights)
    num_heads = attention_weights[0].size(1)
    
    fig, axs = plt.subplots(num_layers, num_heads, figsize=(15, 15))
    
    for layer in range(num_layers):
        for head in range(num_heads):
            sns.heatmap(attention_weights[layer][0, head].cpu().numpy(), ax=axs[layer, head], cmap="viridis", cbar=False)
            axs[layer, head].set_title(f"Layer {layer+1}, Head {head+1}")
            axs[layer, head].axis('off')
    
    plt.tight_layout()
    plt.show()

# Assuming `image` is a single image tensor of shape (batch_size, channels, height, width)
image = next(iter(val_loader))[0][0].unsqueeze(0).to(device)  # Taking one image from validation loader
visualize_attention(model, image)

for image, label in val_loader:
    image, label = image.to(device), label.to(device)
    visualize_attention(model, image)

TypeError: object of type 'NoneType' has no len()