In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Generate synthetic CT-scan data (batches, slices, RGB) and associated segmentation masks
torch.manual_seed(42)
batch = 100
num_slices = 10
channels = 3
width = 256
height = 256

ct_images = torch.randn(size=(batch, num_slices, channels, width, height))
segmentation_masks = (torch.randn(size=(batch, num_slices, 1, width, height))>0).float()

print(f"CT images (train examples) shape: {ct_images.shape}")
print(f"Segmentation binary masks (labels) shape: {segmentation_masks.shape}")

# Define the MedCNN class and its forward method
class MedCNN(nn.Module):
    def __init__(self, backbone, out_channel=1):
        super(MedCNN, self).__init__()
        self.backbone = backbone

        #Downsample
        self.conv1 = nn.Conv3d(512, 64, kernel_size=(3, 3, 3), padding=1)
        self.conv2 = nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=1)

        #Upsample
        self.conv_transpose1 = nn.ConvTranspose3d(64, 32, kernel_size=(1, 4, 4), stride=(1, 4, 4))
        self.conv_transpose2 = nn.ConvTranspose3d(32, 16, kernel_size=(1, 8, 8), stride=(1, 8, 8))

        #Final convolution layer from 16 to 1 channel
        self.final_conv = nn.Conv3d(16, out_channel, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        b, d, c, w, h = x.size() #Input size: [B, D, C, W, H]
        print(f"Input shape [B, D, C, W, H]: {b, d, c, w, h}")

        x = x.view(b*d, c, w, h) #Input to Resent 2DConv layers [B*D, C, W, H]
        features = self.backbone(x)
        print(f"ResNet output shape[B*D, C, W, H]: {features.shape}")

        _, new_c, new_w, new_h = features.size()
        x = features.view(b, d, new_c, new_w, new_h) #[B, D, C, W, H]
        x = torch.permute(x, (0, 2, 1, 3, 4)) #rearrange for 3DConv layers [B, C, D, W, H]
        print(f"Reshape Resnet output for 3DConv #1 [B, C, D, W, H]: {x.shape}")

        #Downsampling
        x = self.relu(self.conv1(x))
        print(f"Output shape 3D Conv #1: {x.shape}")
        x = self.relu(self.conv2(x))
        print(f"Output shape 3D Conv #2: {x.shape}")

        #Upsampling
        x = self.relu(self.conv_transpose1(x))
        print(f"Output shape 3D Transposed Conv #1: {x.shape}")
        x = self.relu(self.conv_transpose2(x))
        print(f"Output shape 3D Transposed Conv #2: {x.shape}")

        #final segmentation
        x = torch.sigmoid(self.final_conv(x))
        print(f"Final shape: {x.shape}")

        return x

def compute_dice_loss(pred, labels, eps=1e-8):
    '''
    Args
    pred: [B, D, 1, W, H]
    labels: [B, D, 1, W, H]

    Returns
    dice_loss: [B, D, 1, W, H]
    '''
    numerator = 2*torch.sum(pred*labels)
    denominator = torch.sum(pred) + torch.sum(labels) + eps
    return numerator/denominator

resnet_model = torchvision.models.resnet18(pretrained=True)
resnet_model = nn.Sequential(*list(resnet_model.children())[:-2])

model = MedCNN(backbone=resnet_model)

optimizer = optim.Adam(model.parameters(), lr=0.01)

epochs = 5
for epoch in range(epochs):
    optimizer.zero_grad()
    pred = model(ct_images)
    loss = compute_dice_loss(pred, segmentation_masks)
    loss.backward()
    optimizer.step()
    print(f"Loss at epoch {epoch}: {loss}")

CT images (train examples) shape: torch.Size([100, 10, 3, 256, 256])
Segmentation binary masks (labels) shape: torch.Size([100, 10, 1, 256, 256])


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 335MB/s]


Input shape [B, D, C, W, H]: (100, 10, 3, 256, 256)
ResNet output shape[B*D, C, W, H]: torch.Size([1000, 512, 8, 8])
Reshape Resnet output for 3DConv #1 [B, C, D, W, H]: torch.Size([100, 512, 10, 8, 8])
Output shape 3D Conv #1: torch.Size([100, 64, 10, 8, 8])
Output shape 3D Conv #2: torch.Size([100, 64, 10, 8, 8])
Output shape 3D Transposed Conv #1: torch.Size([100, 32, 10, 32, 32])
Output shape 3D Transposed Conv #2: torch.Size([100, 16, 10, 256, 256])
Final shape: torch.Size([100, 1, 10, 256, 256])
Loss at epoch 0: 4.883825302124023
Input shape [B, D, C, W, H]: (100, 10, 3, 256, 256)
ResNet output shape[B*D, C, W, H]: torch.Size([1000, 512, 8, 8])
Reshape Resnet output for 3DConv #1 [B, C, D, W, H]: torch.Size([100, 512, 10, 8, 8])
Output shape 3D Conv #1: torch.Size([100, 64, 10, 8, 8])
Output shape 3D Conv #2: torch.Size([100, 64, 10, 8, 8])
Output shape 3D Transposed Conv #1: torch.Size([100, 32, 10, 32, 32])
Output shape 3D Transposed Conv #2: torch.Size([100, 16, 10, 256, 256])

KeyboardInterrupt: 

In [None]:
## Strong LLM
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
from flax.training import train_state

# -------------------------------
# Synthetic Data Generation
# -------------------------------
batch = 100
num_slices = 10
channels = 3
width = 256
height = 256

# Create PRNG keys
key = jax.random.PRNGKey(42)
key, subkey1, subkey2 = jax.random.split(key, 3)

# Generate synthetic CT images and binary segmentation masks
ct_images = jax.random.normal(subkey1, (batch, num_slices, channels, width, height))
segmentation_masks = (jax.random.normal(subkey2, (batch, num_slices, 1, width, height)) > 0).astype(jnp.float32)

print("CT images (train examples) shape:", ct_images.shape)
print("Segmentation binary masks (labels) shape:", segmentation_masks.shape)

# -------------------------------
# Dummy ResNet18 Backbone Definition
# -------------------------------
class DummyResNet18(nn.Module):
    """
    Mimics a truncated ResNet18:
    Given an input of shape [B*D, 3, 256, 256],
    returns a feature map of shape [B*D, 512, new_w, new_h],
    where new_w and new_h are ~8 (downsampling via stride=32).
    """
    @nn.compact
    def __call__(self, x):
        # Using a single convolution with stride 32 to simulate the downsampling:
        x = nn.Conv(features=512,
                    kernel_size=(7, 7),
                    strides=(32, 32),
                    padding='SAME',
                    dimension_numbers=('NCHW', 'OIHW', 'NCHW'))(x)
        x = nn.relu(x)
        return x

# -------------------------------
# MedCNN Model Definition in Flax
# -------------------------------
class MedCNN(nn.Module):
    backbone: nn.Module
    out_channel: int = 1

    @nn.compact
    def __call__(self, x):
        # x shape: [B, D, C, W, H]
        b, d, c, w, h = x.shape
        print("Input shape [B, D, C, W, H]:", (b, d, c, w, h))

        # Reshape for backbone (2D convolutions): [B*D, C, W, H]
        x = x.reshape((b * d, c, w, h))
        features = self.backbone(x)
        print("Backbone (ResNet) output shape [B*D, C, W, H]:", features.shape)

        # Get new dimensions and reshape back: [B, D, new_c, new_w, new_h]
        _, new_c, new_w, new_h = features.shape
        x = features.reshape((b, d, new_c, new_w, new_h))
        # Permute to [B, new_c, D, new_w, new_h] for 3D convolutions
        x = jnp.transpose(x, (0, 2, 1, 3, 4))
        print("Reshaped for 3D conv [B, C, D, W, H]:", x.shape)

        # Define dimension numbers for 3D convolutions (NCDHW ordering)
        dim3d = ("NCDHW", "OIDHW", "NCDHW")

        # Downsampling 3D convolutions:
        x = nn.Conv(features=64, kernel_size=(3, 3, 3), padding='SAME', dimension_numbers=dim3d)(x)
        x = nn.relu(x)
        print("After 3D Conv #1:", x.shape)

        x = nn.Conv(features=64, kernel_size=(3, 3, 3), padding='SAME', dimension_numbers=dim3d)(x)
        x = nn.relu(x)
        print("After 3D Conv #2:", x.shape)

        # Upsampling 3D transposed convolutions:
        x = nn.ConvTranspose(features=32, kernel_size=(1, 4, 4), strides=(1, 4, 4),
                             padding='SAME', dimension_numbers=dim3d)(x)
        x = nn.relu(x)
        print("After 3D Transposed Conv #1:", x.shape)

        x = nn.ConvTranspose(features=16, kernel_size=(1, 8, 8), strides=(1, 8, 8),
                             padding='SAME', dimension_numbers=dim3d)(x)
        x = nn.relu(x)
        print("After 3D Transposed Conv #2:", x.shape)

        # Final segmentation layer (from 16 to 1 channel)
        x = nn.Conv(features=self.out_channel, kernel_size=(1, 1, 1), padding='SAME', dimension_numbers=dim3d)(x)
        x = jax.nn.sigmoid(x)
        print("Final output shape:", x.shape)
        return x

# -------------------------------
# Dice Loss Function
# -------------------------------
def compute_dice_loss(pred, labels, eps=1e-8):
    """
    Args:
      pred: [B, D, 1, W, H]
      labels: [B, D, 1, W, H]

    Returns:
      Dice coefficient (scalar)
    """
    numerator = 2 * jnp.sum(pred * labels)
    denominator = jnp.sum(pred) + jnp.sum(labels) + eps
    return numerator / denominator

# -------------------------------
# Set Up Model and Training State
# -------------------------------
# Instantiate the backbone and MedCNN model.
backbone = DummyResNet18()
model = MedCNN(backbone=backbone)

# Create a training state using Flax's TrainState and Optax's Adam optimizer.
class TrainState(train_state.TrainState):
    pass

# Initialize model parameters using the synthetic data.
rng = jax.random.PRNGKey(0)
params = model.init(rng, ct_images)
tx = optax.adam(learning_rate=0.01)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# -------------------------------
# Training Loop
# -------------------------------
epochs = 5
for epoch in range(epochs):
    def loss_fn(params):
        pred = model.apply(params, ct_images)
        loss = compute_dice_loss(pred, segmentation_masks)
        return loss, pred

    # Compute loss and gradients.
    (loss, pred), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    # Update the model parameters.
    state = state.apply_gradients(grads=grads)
    print(f"Loss at epoch {epoch}: {loss}")


In [None]:
#Weak LLM
import jax
import jax.numpy as jnp
from flax import linen as nn
import numpy as np

# Generate synthetic CT-scan data (batches, slices, RGB) and associated segmentation masks
def generate_synthetic_data(batch_size, num_slices, image_shape):
    # Example function body for generating synthetic data
    ct_scans = jax.random.normal(jax.random.PRNGKey(0), (batch_size, num_slices) + image_shape)
    segmentation_masks = jax.random.randint(jax.random.PRNGKey(1), shape=(batch_size, num_slices), minval=0, maxval=2)
    return ct_scans, segmentation_masks

# Define a loss function
def loss_fn(params, ct_scans, segmentation_masks):
    # Placeholder logic for a loss function
    predictions = dummy_model(params, ct_scans)  # Assume dummy_model is defined elsewhere
    return jnp.mean((predictions - segmentation_masks) ** 2)

# Define a training step function using JAX's jitting
@jax.jit
def train_step(params, ct_scans, segmentation_masks, prng_key):
    loss_value = loss_fn(params, ct_scans, segmentation_masks)
    return loss_value

# Vectorized training function to avoid Python loops // MODIFIED
def train(params, segmentation_masks):
    # Create a PRNG key
    prng_key = jax.random.PRNGKey(2)

    # Generate synthetic data
    ct_scans, _ = generate_synthetic_data(params['batch_size'], params['num_slices'], params['image_shape'])

    # Forward pass through the training function
    loss_value = train_step(params, ct_scans, segmentation_masks, prng_key) // MODIFIED

    print(f'Loss at epoch: {loss_value}')  # Adjusted to show loss for the single epoch

# Entry point of the program
if __name__ == "__main__":
    try:
        # Example parameter initialization
        params = {
            'batch_size': 16,
            'num_slices': 10,
            'image_shape': (224, 224, 3)
        }
        segmentation_masks = np.random.randint(0, 2, size=(params['batch_size'], params['num_slices']))  # Dummy masks for illustration
        train(params, segmentation_masks)
        print("Training completed successfully.")  # Placeholder for actual logic
    except Exception as e:
        print(f"An error occurred during training: {e}")

An error occurred during training: name 'dummy_model' is not defined


In [None]:
"""Error Code####
main()  # Run the main function
evaluate_model(state, images, labels)
logits = state.apply_fn({'params': state.params}, images)

##Error:
# ApplyScopeInvalidVariablesStructureError
expect the `variables` (first argument) passed to apply() to be a dict with the structure {"params": ...},
but got a dict with an extra params layer, i.e.  {"params": {"params": ... } }.

## Fix guide
You should instead pass in your dict's ["params"].

# Correct Code
logits = state.apply_fn(state.params, images)
"""


"""Error Code####
    predictions = dummy_model(params, ct_scans)

  Error:
  An error occurred during training: name 'dummy_model' is not defined

  Fix guide:
  change dummy_model to model

  # Correct Code
  predictions = model(params, ct_scans)

"""

"""
Error Code:
@jax.jit
def train_step(model_params, model, ct_scans, segmentation_masks, prng_key):

Error:
Error interpreting argument to <function train_step at 0x7843144d6d40>
as an abstract array.
The problematic value is of type <class '__main__.SimpleModel'> and was passed
to the function at path model.
This typically means that a jit-wrapped function was called
with a non-array argument,
and this argument was not marked as static using the static_argnums or
static_argnames parameters of jax.jit.

Fix guide:
modify the @jax.jit decorator to specify that model is a static argument
using the static_argnums parameter. Since model is the second argument
(index 1 in Python’s zero-based indexing), you’ll set static_argnums=(1,).
This tells JAX to treat model as a fixed object that doesn’t change
during JIT compilation, while the other arguments remain traceable arrays.

Corrext Code
@jax.jit(static_argnums=(1,))
def train_step(model_params, model, ct_scans, segmentation_masks, prng_key):
"""

"""
##Error Code:
def compute_dice_loss(pred, labels, eps=1e-8):
    numerator = 2 * jnp.sum(pred * labels)
    denominator = jnp.sum(pred) + jnp.sum(labels) + eps
    return numerator / denominator
## Error:
TypeError: mul got incompatible shapes for broadcasting: (100, 0, 320, 256, 1),
(100, 10, 1, 256, 256).

## Fix guide


"""






In [1]:
## Fixed code
# !pip install numpyro==0.13.1
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import numpy as np

# Set random seed for reproducibility (equivalent to torch.manual_seed(42))
rng = jax.random.PRNGKey(42)

# Generate synthetic CT-scan data (batches, slices, channels, width, height)
batch = 5
num_slices = 10
channels = 3
width = 256
height = 256

def generate_synthetic_data(rng, batch, num_slices, channels, width, height):
    rng_data, rng_masks = jax.random.split(rng)
    ct_images = jax.random.normal(rng_data, (batch, num_slices, channels, width, height))
    segmentation_masks = (jax.random.normal(rng_masks, (batch, num_slices, 1, width, height)) > 0).astype(jnp.float32)
    return ct_images, segmentation_masks

ct_images, segmentation_masks = generate_synthetic_data(rng, batch, num_slices, channels, width, height)
print(f"CT images (train examples) shape: {ct_images.shape}")
print(f"Segmentation binary masks (labels) shape: {segmentation_masks.shape}")

# Define the MedCNN class in Flax
class MedCNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        b, d, c, w, h = x.shape  # Input size: [B, D, C, W, H]
        print(f"Input shape [B, D, C, W, H]: {(b, d, c, w, h)}")

        x = x.reshape(b * d, c, w, h)  # [B*D, C, W, H]
        x = jnp.transpose(x, (0, 2, 3, 1))  # [B*D, W, H, C] = (1000, 256, 256, 3) for Flax NHWC
        x = nn.Conv(features=512, kernel_size=(3, 3), padding="SAME")(x)  # Simplified ResNet-like layer
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(32, 32), strides=(32, 32), padding="VALID")  # Downsample to 8x8
        x = jnp.transpose(x, (0, 3, 1, 2))  # [B*D, 512, 8, 8] to match PyTorch NCHW
        print(f"ResNet-like output shape [B*D, C, W, H]: {x.shape}")

        # Reshape back for 3D conv layers
        _, new_c, new_w, new_h = x.shape
        x = x.reshape(b, d, new_c, new_w, new_h)  # [B, D, C, W, H]
        x = jnp.transpose(x, (0, 2, 1, 3, 4))  # [B, C, D, W, H]
        print(f"Reshape ResNet output for 3DConv #1 [B, C, D, W, H]: {x.shape}")

        # Downsampling
        x = nn.Conv(features=64, kernel_size=(3, 3, 3), padding="SAME")(x)
        x = nn.relu(x)
        print(f"Output shape 3D Conv #1: {x.shape}")
        x = nn.Conv(features=64, kernel_size=(3, 3, 3), padding="SAME")(x)
        x = nn.relu(x)
        print(f"Output shape 3D Conv #2: {x.shape}")

        # Upsampling
        x = nn.ConvTranspose(features=32, kernel_size=(1, 4, 4), strides=(1, 4, 4), padding="VALID")(x)
        x = nn.relu(x)
        print(f"Output shape 3D Transposed Conv #1: {x.shape}")
        x = nn.ConvTranspose(features=16, kernel_size=(1, 8, 8), strides=(1, 8, 8), padding="VALID")(x)
        x = nn.relu(x)
        print(f"Output shape 3D Transposed Conv #2: {x.shape}")

        # Final segmentation
        x = nn.Conv(features=1, kernel_size=(1, 1, 1))(x)
        x = jax.nn.sigmoid(x)
        print(f"Final shape: {x.shape}")

        return x

# Dice loss function
def compute_dice_loss(pred, labels, eps=1e-8):
    numerator = 2 * jnp.sum(pred * labels)
    denominator = jnp.sum(pred) + jnp.sum(labels) + eps
    print(f"Dice numerator: {numerator}")
    print(f"Dice denominator: {denominator}")

    return numerator / denominator

# Training step with JIT
@jax.jit
def train_step(params, state, ct_images, segmentation_masks):
    def loss_fn(params):
        pred = model.apply({'params': params}, ct_images)
        dice = compute_dice_loss(pred, segmentation_masks)
        return 1 - dice  # Convert to loss (1 - Dice score)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, state = optimizer.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    return params, state, loss

# Initialize model and optimizer
model = MedCNN()
rng_init, rng_train = jax.random.split(rng)
dummy_input = jnp.ones((batch, num_slices, channels, width, height))
params = model.init(rng_init, dummy_input)['params']
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)


# Training loop
epochs = 5
for epoch in range(epochs):
    params, opt_state, loss = train_step(params, opt_state, ct_images, segmentation_masks)
    print(f"Loss at epoch {epoch}: {loss}")

print("Training completed successfully.")

CT images (train examples) shape: (5, 10, 3, 256, 256)
Segmentation binary masks (labels) shape: (5, 10, 1, 256, 256)
Input shape [B, D, C, W, H]: (5, 10, 3, 256, 256)
ResNet-like output shape [B*D, C, W, H]: (50, 512, 8, 8)
Reshape ResNet output for 3DConv #1 [B, C, D, W, H]: (5, 512, 10, 8, 8)
Output shape 3D Conv #1: (5, 512, 10, 8, 64)
Output shape 3D Conv #2: (5, 512, 10, 8, 64)
Output shape 3D Transposed Conv #1: (5, 512, 40, 32, 32)
Output shape 3D Transposed Conv #2: (5, 512, 320, 256, 16)
Final shape: (5, 512, 320, 256, 1)


KeyboardInterrupt: 