In [1]:
# %pip (not !pip) installs into the environment of the running kernel; pin versions for reproducible runs.
%pip install -q torch torchvision numpy matplotlib



In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models.video import r3d_18
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# Define synthetic 3D dataset
class Synthetic3DDataset(Dataset):
    """
    In-memory dataset of random single-channel 3D volumes with random integer labels.

    Volumes are drawn with ``torch.rand`` and labels with ``torch.randint(0, num_classes)``.
    Labels are independent of the data, so this only exercises the data/model/training
    plumbing; there is no learnable signal.

    Args:
        num_samples: Number of (volume, label) pairs.
        img_size: Spatial dims of each volume, placed after the single channel dim.
        num_classes: Labels are drawn from ``range(num_classes)``.
        seed: Optional seed for a private RNG. ``None`` (default) draws from torch's
            global RNG. An int makes the data reproducible regardless of the global
            RNG state.
    """

    def __init__(self, num_samples=100, img_size=(32, 32, 32), num_classes=2, seed=None):
        self.num_samples = num_samples
        self.img_size = img_size
        self.num_classes = num_classes

        # generator=None falls back to torch's global RNG.
        generator = torch.Generator().manual_seed(seed) if seed is not None else None

        # Each item's volume is (1, *img_size): one channel followed by the spatial dims
        self.data = torch.rand(num_samples, 1, *img_size, generator=generator)
        self.labels = torch.randint(0, num_classes, (num_samples,), generator=generator)

    def __len__(self):
        """Number of samples in the dataset."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(volume, label)`` for index ``idx``."""
        return self.data[idx], self.labels[idx]

In [4]:
# Modify the ResNet model to accept 1-channel input
class MRI3DResNet(nn.Module):
    """
    torchvision's 3D ResNet-18 (``r3d_18``), randomly initialized and adapted for
    ``in_channels``-channel input volumes and ``num_classes`` outputs.

    Args:
        num_classes: Output size of the final linear layer.
        in_channels: Number of input channels of the stem convolution (default 1).
    """

    def __init__(self, num_classes=2, in_channels=1):
        super().__init__()
        # weights=None -> random init; replaces the deprecated `pretrained=False` kwarg
        self.model = r3d_18(weights=None)

        # Replace the stem convolution with one taking `in_channels` inputs. Copy every
        # other hyper-parameter from the layer being replaced so only the input
        # channel count changes.
        old_conv = self.model.stem[0]
        self.model.stem[0] = nn.Conv3d(
            in_channels=in_channels,
            out_channels=old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=old_conv.bias is not None,
        )

        # Replace the classification head for the requested number of classes
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """
        Return class logits of shape (batch, num_classes).

        Assumes ``x`` is a 5D batch (batch, in_channels, depth, height, width), as
        used with the synthetic dataset here. TODO confirm for other inputs.
        """
        return self.model(x)

In [5]:
# Initialize dataset and data loader.
# Seed torch's global RNG first. When the cells run in order, this makes reproducible:
# the random volumes/labels, the loader's shuffle order, and the model's weight init
# in the next cell. GPU kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

dataset = Synthetic3DDataset(num_samples=100, img_size=(32, 32, 32), num_classes=2)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [6]:
# Build the classifier and place it on the GPU when one is available, else the CPU.
num_classes = 2
device = "cuda" if torch.cuda.is_available() else "cpu"
# nn.Module.to moves parameters in place and returns the module itself
model = MRI3DResNet(num_classes=num_classes).to(device)



In [7]:
# Training setup.
# CrossEntropyLoss consumes the model's raw linear-head outputs (logits) plus integer labels.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses this epoch
    num_seen = 0        # samples processed this epoch
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Clear stale gradients, then forward, backward and update
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion uses the default 'mean' reduction, so loss is a per-batch mean.
        # Weight it by batch size so a short final batch doesn't skew the epoch average.
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        num_seen += batch_size

    epoch_loss = running_loss / max(num_seen, 1)  # guard against an empty loader
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

print("Training complete!")

Epoch [1/5], Loss: 0.7321


In [11]:
# Visualize a sample and its prediction
def visualize_sample(data, label, output=None, slice_idx=None):
    """
    Plot one depth slice of a 3D volume, optionally next to the model's prediction.

    Args:
        data: Tensor shaped (channels, depth, height, width) for a single sample,
            or (batch, channels, depth, height, width); for a batch, the first
            sample is shown. Only the first channel is displayed.
        label: Ground-truth label shown in the left panel's title.
        output: Optional logits shaped (num_classes,) or (batch, num_classes);
            the argmax of the first row is shown as the prediction.
        slice_idx: Index along the depth axis; defaults to the middle slice.

    Raises:
        ValueError: If ``data`` is neither 4D nor 5D.
    """
    # Accept both a single sample and a batch; reduce to one (C, D, H, W) volume.
    volume = data[0] if data.dim() == 5 else data
    if volume.dim() != 4:
        raise ValueError(
            "expected data shaped (C, D, H, W) or (B, C, D, H, W), "
            f"got {tuple(data.shape)}"
        )

    if slice_idx is None:
        slice_idx = volume.shape[1] // 2  # Middle slice along the depth axis

    # First channel at the chosen depth: shape (height, width)
    slice_2d = volume[0, slice_idx].detach().cpu().numpy()

    fig, (ax_input, ax_pred) = plt.subplots(1, 2, figsize=(10, 4))
    ax_input.imshow(slice_2d, cmap="gray")
    ax_input.set_title(f"Input Slice (Label: {label})")
    ax_input.axis("off")

    if output is None:
        ax_pred.axis("off")  # keep the two-panel layout with a blank right panel
    else:
        # Accept per-sample logits (K,) or batched logits (B, K)
        logits = output[0] if output.dim() > 1 else output
        pred_label = logits.argmax(dim=-1).item()
        ax_pred.imshow(slice_2d, cmap="gray")
        ax_pred.set_title(f"Predicted: {pred_label}")
        ax_pred.axis("off")

    plt.show()

# Example visualization: first sample of one batch, with the model's prediction
model.eval()
with torch.no_grad():
    inputs, labels = next(iter(dataloader))
    outputs = model(inputs.to(device))

# Slice with [:1] rather than [0] to keep the batch dimension: visualize_sample then
# receives 5D data (1, C, D, H, W) and 2D logits (1, num_classes).
visualize_sample(inputs[:1], labels[0].item(), outputs[:1])