In [6]:
%pip install boto3 nibabel numpy matplotlib torch torchvision torchaudio

Note: you may need to restart the kernel to use updated packages.


In [7]:
import boto3
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import io
import tempfile
import os
import torch

# Select the compute device once: CUDA when torch reports it as usable, CPU otherwise
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print("CUDA is available. Using GPU.")
    print(f"GPU device name: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA is not available. Using CPU.")
print(f"Using device: {device}")

# Where the source volumes live: bucket name plus the key prefix that is scanned for .nii files
bucket_name = 'chemocraft-data'
folder_path = 'MICCAI_BraTS2020_TrainingData/'
# folder_path = 'Data/BraTS20_Training_369 copy/'

# Module-level boto3 handles shared by the helper functions defined in this cell
s3 = boto3.resource('s3')
bucket = s3.Bucket(bucket_name)

def plot_slice(data, crop, slice_idx, filename):
    """Show one cropped Z-slice of a volume as a grayscale image.

    Parameters
    ----------
    data : array
        Volume indexed as ``data[row, col, slice]``.
    crop : sequence
        ``[[row_start, row_stop], [col_start, col_stop]]`` bounds applied to the slice.
    slice_idx : int
        Index along the third axis of ``data`` to display.
    filename : str
        Only used in the figure title.
    """
    # Select the slice and apply both crop ranges in a single indexing step
    region = data[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1], slice_idx]

    # Render with the pyplot state machine; axes are hidden for a cleaner view
    plt.figure(figsize=(6, 6))
    plt.imshow(region, cmap='gray')
    plt.title(f'Slice {slice_idx} of {filename}')
    plt.axis('off')
    plt.show()

def savePNG(data, crop, filename):
    """Crop every Z-slice of a volume, render it as a grayscale PNG and upload it to S3.

    Parameters
    ----------
    data : array
        Volume indexed as ``data[row, col, slice]``; one PNG is produced per
        index of the third axis.
    crop : sequence
        ``[[row_start, row_stop], [col_start, col_stop]]`` bounds applied to each slice.
    filename : str
        Source file name. The text before the first "." is split on "_"; the
        last two tokens become the two directory levels of the output key,
        i.e. ``Akshay/brain_slices/<second-to-last>/<last>/<slice_idx>.png``.

    Notes
    -----
    Uploads are best-effort: a failed upload is reported with ``print`` and the
    loop moves on to the next slice. The temporary PNG is removed whether or
    not the upload succeeded, so failed uploads no longer leave files behind.
    """
    # Prepare directory structure
    fileWOext = filename.split(".")[0]
    name_tokens = fileWOext.split("_")
    TrainingCount = name_tokens[-2]
    ScanType = name_tokens[-1]
    slice_path = f"brain_slices/{TrainingCount}/{ScanType}/"
    print(f"Saving in directory: {slice_path}")

    # The bucket handle does not change between slices, so build it once
    target_bucket = s3.Bucket(bucket_name)

    # Iterate through each slice in the Z-dimension of the data and save it as PNG
    for slice_idx in range(data.shape[2]):
        # Crop each slice
        slice_2d = data[:, :, slice_idx]
        cropped_slice = slice_2d[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]
        png_filename = f"{slice_path}{slice_idx}.png"

        # Reserve a unique temp path and close the handle right away: the image
        # is written by name, so keeping our own handle open serves no purpose.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_png:
            temp_png_name = temp_png.name

        try:
            # Rendering errors still propagate to the caller, as before
            mpimg.imsave(temp_png_name, cropped_slice, cmap='gray')
            try:
                target_bucket.upload_file(temp_png_name, f"Akshay/{png_filename}")
            except Exception as e:
                print(f"ERROR: Could not upload or delete temporary PNG file due to {e}")
        finally:
            # Always clean up, including when the upload or the render failed
            try:
                os.remove(temp_png_name)
            except OSError as e:
                print(f"ERROR: Could not upload or delete temporary PNG file due to {e}")

def render_nii_from_s3(filename, path):
    """Download one .nii object from the module-level bucket and export its slices.

    The object ``path + filename`` is fetched, written to a temporary ``.nii``
    file (nibabel is given a path to load from), loaded, and handed to
    ``savePNG`` together with a fixed-margin crop. The temporary file is
    removed afterwards regardless of the outcome.

    Parameters
    ----------
    filename : str
        Object name without its key prefix.
    path : str
        Key prefix; concatenated with ``filename`` as-is to form the S3 key.

    Returns
    -------
    None
        All failures are reported with ``print``; nothing is raised for a
        missing key, a download error, or an error while loading/exporting.
    """
    print(f"Fetching file: {filename}")

    try:
        obj = bucket.Object(path + filename)
        payload = obj.get()['Body'].read()
    except s3.meta.client.exceptions.NoSuchKey as e:
        print(f"ERROR: The specified key does not exist: {path + filename}")
        return
    except Exception as e:
        print(f"ERROR: An unexpected error occurred: {e}")
        return

    # delete=False so the file survives the `with` and can be reopened by path
    with tempfile.NamedTemporaryFile(suffix='.nii', delete=False) as temp_file:
        temp_file.write(payload)
        temp_file_path = temp_file.name
        print(f"Temporary file created: {temp_file_path}")

    try:
        img = nib.load(temp_file_path)
        data = img.get_fdata()

        print(f"Data shape for {filename}: {data.shape}")

        if data.size == 0:
            print(f"No data found in {filename}")
            return

        # Define crop dimensions. crop[0] bounds axis 0 (rows) and crop[1]
        # bounds axis 1 (columns) in savePNG/plot_slice, so each upper bound
        # must come from the length of the axis it is applied to. (Previously
        # the two axis lengths were swapped, which only worked for square slices.)
        n_rows, n_cols = data.shape[0], data.shape[1]
        croptop = 40
        cropbottom = n_rows - 40
        cropleft = 25
        cropright = n_cols - 15
        crop = np.array([[croptop, cropbottom], [cropleft, cropright]])

        # Save the PNGs and plot a sample slice
        savePNG(data, crop, filename)

        # slice_idx = 88  # Choose a slice index for sample display
        # plot_slice(data, crop, slice_idx, filename)

    except Exception as e:
        print(f"Error loading file {filename}: {e}")

    finally:
        # Runs on success, on error, and on the early return above
        try:
            os.remove(temp_file_path)
            print(f"Deleted temporary file: {temp_file_path}")
        except OSError as cleanup_error:
            print(f"Error deleting temp file: {cleanup_error}")

def find_and_render_nii_files():
    """Walk the first-level sub-folders of ``folder_path`` and process every .nii object.

    A first listing of the bucket collects the distinct sub-folder names found
    directly under ``folder_path``. Each sub-folder is then listed again, in
    sorted order, and every key ending in ``.nii`` is passed to
    ``render_nii_from_s3`` as (file name, sub-folder prefix). A message is
    printed if no such key was seen at all.
    """
    prefix_len = len(folder_path)

    # Keys that still contain a "/" after the prefix is stripped live in a
    # sub-folder; a set comprehension keeps each sub-folder name only once.
    subfolders = sorted({
        f"{relative.split('/')[0]}/"
        for relative in (entry.key[prefix_len:] for entry in bucket.objects.filter(Prefix=folder_path))
        if '/' in relative
    })

    print(f"Root Directory: {folder_path.split('/')[0]}")
    # print(subfolders)

    rendered = 0
    for subfolder in subfolders:
        path = folder_path + subfolder
        print(f"Reading S3 in {path}")

        # Lazy generator: keys are listed and processed one at a time
        nii_keys = (entry.key for entry in bucket.objects.filter(Prefix=path) if entry.key.endswith('.nii'))
        for key in nii_keys:
            print(f"path: {path}")
            rendered += 1
            filename = key.rsplit('/', 1)[-1]  # last path component is the file name
            print(f"Found .nii file: {filename}")
            render_nii_from_s3(filename, path)

    if rendered == 0:
        print(f"No .nii files found in the folder {folder_path}")

# Main function
# find_and_render_nii_files()


# Old implementation of File Reading
# for obj in bucket.objects.filter(Prefix=folder_path):
#     if obj.key.endswith('.nii'):
#         found_files = True
#         filename = obj.key.split('/')[-1]  # Extract filename from path
#         print(f"Found .nii file: {filename}")
#         # render_nii_from_s3(filename)


# if not found_files:
#     print(f"No .nii files found in the folder {folder_path}")

CUDA is available. Using GPU.
GPU device name: NVIDIA GeForce GTX 1650
Using device: cuda


In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import boto3
import io

import torch.nn as nn
import torch.optim as optim
# Define the S3 bucket and other variables
s3 = boto3.resource('s3')
# Must be the bucket savePNG uploads to. These module-level names are shared
# with the slicing helpers defined earlier, so a placeholder value here would
# both break the dataset listing and redirect any later savePNG call.
bucket_name = 'chemocraft-data'
bucket = s3.Bucket(bucket_name)

# Key prefix written by savePNG: "Akshay/" + "brain_slices/<case>/<scan>/<idx>.png"
folder_prefix = "Akshay/brain_slices/"
ngpu = torch.cuda.device_count()  # Number of GPUs available. Use 0 for CPU mode.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

class S3ImageDataset(Dataset):
    """Dataset of PNG images stored under one key prefix of an S3 bucket.

    The list of ``.png`` keys is fetched once at construction time; the image
    bytes themselves are downloaded on every ``__getitem__`` call.
    """

    def __init__(self, bucket, folder_prefix, transform=None):
        """Remember the bucket/transform and index every ``.png`` key under ``folder_prefix``."""
        self.bucket = bucket
        self.folder_prefix = folder_prefix
        self.transform = transform

        png_keys = []
        for entry in bucket.objects.filter(Prefix=folder_prefix):
            if entry.key.endswith('.png'):
                png_keys.append(entry.key)
        self.image_keys = png_keys

    def __len__(self):
        """Number of indexed PNG keys."""
        return len(self.image_keys)

    def __getitem__(self, idx):
        """Download image ``idx``, convert it to mode 'L', and apply the transform if one was given."""
        payload = self.bucket.Object(self.image_keys[idx]).get()['Body'].read()
        image = Image.open(io.BytesIO(payload)).convert('L')  # single-channel grayscale
        return self.transform(image) if self.transform else image

# Define the transform: resize, convert to a tensor in [0, 1], then shift/scale with mean 0.5 and std 0.5
# NOTE(review): Resize takes (height, width). The crop applied before savePNG trims
# 80 rows and 40 columns, so square source slices give PNGs that are wider than tall,
# while (200, 160) is taller than wide — confirm the intended orientation.
transform = transforms.Compose([
    transforms.Resize((200, 160)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Mini-batch size for the DataLoader. It was used below without ever being
# assigned, which raises NameError on a fresh kernel. 32 is a conservative
# starting value — tune it to the available GPU memory.
batch_size = 32

# Create the dataset and dataloader
dataset = S3ImageDataset(bucket, folder_prefix, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Print the new image sizes
for i, img in enumerate(dataloader):
    print(f"Batch {i+1} - Image size: {img.size()}")
    break  # Print the size of the first batch only