In [None]:
import multiprocessing
import os

import numpy as np
import pandas as pd
import pydicom
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import ViTConfig, ViTModel

In [None]:
# Seed every RNG the pipeline draws from: torch (random_split, head weight
# init, DataLoader shuffling) and numpy (the per-item slice pick in
# BrainTumorDataset.__getitem__).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# google/vit-base-patch16-224-in21k was pre-trained on inputs normalized with
# mean = std = 0.5 per channel (its ViTImageProcessor config), not with the
# ImageNet statistics; matching them keeps inputs in the backbone's expected range.
VIT_MEAN = [0.5, 0.5, 0.5]
VIT_STD = [0.5, 0.5, 0.5]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=VIT_MEAN, std=VIT_STD),
])

In [None]:
def dicom_to_pil(dicom_file):
    """Read a DICOM slice and return it as an 8-bit RGB PIL image.

    Negative pixel values are clipped to 0. The remaining range is scaled
    linearly so that the brightest pixel maps to 255. The grayscale result is
    then replicated into three channels so it can feed an RGB model.

    Assumes a single-frame (2-D) pixel array; multi-frame files are not
    handled here (TODO confirm the dataset has none).

    Args:
        dicom_file: Path (or file-like object) accepted by ``pydicom.dcmread``.

    Returns:
        A PIL image in mode ``'RGB'``.
    """
    dicom = pydicom.dcmread(dicom_file)
    image = dicom.pixel_array.astype(float)

    clipped = np.maximum(image, 0)
    peak = clipped.max()
    if peak > 0:
        scaled = ((clipped / peak) * 255.0).astype(np.uint8)
    else:
        # Blank (all-zero or all-negative) slice: dividing by the maximum would
        # yield NaN, whose uint8 cast is undefined. Emit a black image instead.
        scaled = np.zeros(image.shape, dtype=np.uint8)

    # MONOCHROME1 stores inverted grayscale (low values display as white).
    # Flip it so every slice follows the MONOCHROME2 convention.
    if getattr(dicom, 'PhotometricInterpretation', None) == 'MONOCHROME1':
        scaled = 255 - scaled

    return Image.fromarray(scaled).convert('RGB')

In [None]:
class BrainTumorDataset(Dataset):
    """Dataset of per-study DICOM folders with one label per study.

    Each CSV row names a sub-folder of ``root_dir`` (column 0) and its label
    (column 1). Every ``__getitem__`` call picks one ``.dcm`` slice from that
    folder uniformly at random via ``np.random``. The same index can
    therefore yield a different slice on each access.

    Args:
        csv_file: Path to the CSV; its first row is read as the header.
        root_dir: Directory containing one sub-folder per CSV row.
        transform: Optional callable applied to the PIL image.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        # Keep the folder-name column as text so IDs such as "00012" retain
        # their leading zeros and can be joined onto a filesystem path.
        self.data = pd.read_csv(csv_file, converters={0: str})
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        folder_name = self.data.iloc[idx, 0]
        label = self.data.iloc[idx, 1]
        folder_path = os.path.join(self.root_dir, folder_name)

        # Sort so the candidate list (and therefore a seeded random pick) does
        # not depend on filesystem directory order. Match the extension
        # case-insensitively.
        dicom_files = sorted(
            f for f in os.listdir(folder_path) if f.lower().endswith('.dcm')
        )
        if not dicom_files:
            raise ValueError(f"No .dcm files found in {folder_path!r}")

        dicom_file = np.random.choice(dicom_files)
        img_path = os.path.join(folder_path, dicom_file)

        image = dicom_to_pil(img_path)

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# Build the dataset and split it 80/20 into train/validation subsets.
dataset = BrainTumorDataset(csv_file='train.csv', root_dir='data', transform=transform)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A dedicated generator keeps the split identical even if this cell is re-run
# on its own (the global RNG would have advanced). Seeding it from
# torch.initial_seed() ties it to the seed set earlier in the notebook.
split_generator = torch.Generator().manual_seed(torch.initial_seed())
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Only 'fork' workers inherit classes/functions defined in the notebook.
# 'spawn'/'forkserver' workers (the macOS/Windows defaults) must re-import
# them from __main__ and fail for notebook-defined code, so load in-process there.
num_workers = 4 if multiprocessing.get_start_method() == 'fork' else 0

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=num_workers)

In [None]:
class HybridVisionTransformer(nn.Module):
    """Pre-trained ViT backbone followed by a small MLP head on the [CLS] embedding.

    Args:
        num_classes: Number of output logits.
        model_name: Hugging Face checkpoint passed to ``ViTModel.from_pretrained``.
        hidden_dim: Width of the intermediate fully connected layer.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, num_classes=2, model_name='google/vit-base-patch16-224-in21k',
                 hidden_dim=256, dropout=0.5):
        super().__init__()

        self.vit = ViTModel.from_pretrained(model_name)

        # Size the head from the backbone config instead of hard-coding 768, so
        # checkpoints with a different embedding width (e.g. ViT-Large, 1024) work.
        self.fc1 = nn.Linear(self.vit.config.hidden_size, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, output_attentions=True):
        """Return ``(logits, attentions)`` for a batch of images.

        Args:
            x: Pixel tensor fed to the ViT. Presumably (batch, 3, 224, 224) as
                produced by this notebook's transform (TODO confirm for other inputs).
            output_attentions: When False, the ViT does not collect per-layer
                attention maps (cheaper in time and memory), and ``attentions``
                is then ``None``.

        Returns:
            A tuple of logits with shape (batch, num_classes) and the ViT's
            per-layer attention weights (or ``None``).
        """
        vit_output = self.vit(x, output_attentions=output_attentions)
        cls_embedding = vit_output.last_hidden_state[:, 0, :]  # [CLS] token

        x = self.fc1(cls_embedding)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return x, vit_output.attentions

In [None]:
'''
The __init__ method initializes the model:

It loads a Vision Transformer (ViT) backbone pre-trained on ImageNet-21k (the google/vit-base-patch16-224-in21k checkpoint).
It creates a fully connected layer (fc1) that reduces the dimensionality from 768 to 256.
It sets up a ReLU activation and a dropout layer for regularization.
It creates another fully connected layer (fc2) that maps from 256 to the number of classes.
'''

In [None]:
'''
The forward method defines how data flows through the network:

It passes the input through the ViT model, getting both the hidden states and attention weights.
It extracts the [CLS] token representation (the first token, used for classification).
The [CLS] embedding then passes through fc1, a ReLU activation, dropout, and finally fc2.
The final output is the classification logits and the attention weights from the ViT.
'''

In [None]:
model = HybridVisionTransformer(num_classes=2)

# Prefer a CUDA GPU, then Apple-Silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model = model.to(device)

# NOTE(review): 1e-3 with Adam is high for fully fine-tuning a pretrained ViT;
# values around 1e-4 to 2e-5 are more typical. Confirm this is intended.
LEARNING_RATE = 0.001

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over ``dataloader``.

    ``model(inputs)`` must return a tuple whose first element is the logits;
    the second element is ignored. ``criterion`` is assumed to return the
    mean loss over the batch (the ``nn.CrossEntropyLoss`` default).

    Args:
        model: Module to train; switched to train mode.
        dataloader: Iterable of ``(inputs, labels)`` tensor batches.
        criterion: Loss function applied to ``(logits, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``:
            - ``epoch_loss`` is averaged per sample, so a smaller final batch
              is not over-weighted.
            - ``epoch_acc`` is the fraction of correct argmax predictions.
            - Both are ``nan`` if ``dataloader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs, _ = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Re-weight the batch-mean loss by its size to obtain a per-sample average.
        loss_sum += loss.item() * batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += batch_size

    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, correct / total

In [None]:
def evaluate(model, dataloader, criterion, device):
    """Compute loss and accuracy over ``dataloader`` without updating weights.

    Puts ``model`` in eval mode (disabling dropout) and runs under
    ``torch.no_grad()``. ``model(inputs)`` must return a tuple whose first
    element is the logits. ``criterion`` is assumed to return the batch-mean
    loss (the ``nn.CrossEntropyLoss`` default).

    Args:
        model: Module to evaluate; switched to eval mode.
        dataloader: Iterable of ``(inputs, labels)`` tensor batches.
        criterion: Loss function applied to ``(logits, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: per-sample average loss and accuracy.
        Both are ``nan`` if ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Re-weight the batch-mean loss so a short final batch counts per sample.
            loss_sum += loss.item() * batch_size
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += batch_size

    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, correct / total

In [None]:
# Train for `num_epochs` passes, validating after each one.
num_epochs = 1
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    # Per-epoch summary: three report lines followed by a blank separator line.
    print(
        f"Epoch {epoch+1}/{num_epochs}",
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}",
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}",
        "",
        sep="\n",
    )