In [15]:
import numpy as np
a = np.load('/home/user/ckwan1/ml/mlsimdata_npy/mlsimdata1/213864.npy')

In [18]:
a[:,:,:,6:9].shape

(32, 32, 32, 3)

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import lightning.pytorch as pl
from periodic_padding import periodic_padding_3d

def crop_tensor(x):
	x = x.narrow(2,1,x.shape[2]-3).narrow(3,1,x.shape[3]-3).narrow(4,1,x.shape[4]-3).contiguous()
	return x

def conv3x3(inplane,outplane, stride=1,padding=0):
	return nn.Conv3d(inplane,outplane,kernel_size=3,stride=stride,padding=padding,bias=True)
    
# Assuming conv3x3 and BasicBlock are defined as in your original code.
class BasicBlock(nn.Module):
	def __init__(self,inplane,outplane,stride = 1):
		super(BasicBlock, self).__init__()
		self.conv1 = conv3x3(inplane,outplane,padding=0,stride=stride)
		self.bn1 = nn.BatchNorm3d(outplane)
		self.relu = nn.ReLU(inplace=True)

	def forward(self,x):
		x = periodic_padding_3d(x,pad=(1,1,1,1,1,1))
		out = self.conv1(x)
		out = self.bn1(out)
		out = self.relu(out)
		return out


class Lpt2NbodyNet(nn.Module):
    def __init__(self, block):
        super(Lpt2NbodyNet, self).__init__()
        self.layer1 = self._make_layer(block, 3, 64, blocks=2, stride=1)
        self.layer2 = self._make_layer(block, 64, 128, blocks=1, stride=2)
        self.layer3 = self._make_layer(block, 128, 128, blocks=2, stride=1)
        self.layer4 = self._make_layer(block, 128, 256, blocks=1, stride=2)
        self.layer5 = self._make_layer(block, 256, 256, blocks=2, stride=1)
        self.deconv1 = nn.ConvTranspose3d(256, 128, 3, stride=2, padding=0)
        self.deconv_batchnorm1 = nn.BatchNorm3d(num_features=128, momentum=0.1)
        self.layer6 = self._make_layer(block, 256, 128, blocks=2, stride=1)
        self.deconv2 = nn.ConvTranspose3d(128, 64, 3, stride=2, padding=0)
        self.deconv_batchnorm2 = nn.BatchNorm3d(num_features=64, momentum=0.1)
        self.layer7 = self._make_layer(block, 128, 64, blocks=2, stride=1)
        self.deconv4 = nn.ConvTranspose3d(64, 3, 1, stride=1, padding=0)

    def _make_layer(self, block, inplanes, outplanes, blocks, stride=1):
        layers = []
        for _ in range(blocks):
            layers.append(block(inplanes, outplanes, stride=stride))
            inplanes = outplanes
        return nn.Sequential(*layers)

    def forward(self, x):
        x1 = self.layer1(x)
        x = self.layer2(x1)
        x2 = self.layer3(x)
        x = self.layer4(x2)
        x = self.layer5(x)
        x = periodic_padding_3d(x, pad=(0, 1, 0, 1, 0, 1))
        x = nn.functional.relu(self.deconv_batchnorm1(crop_tensor(self.deconv1(x))), inplace=True)
        x = torch.cat((x, x2), dim=1)
        x = self.layer6(x)
        x = periodic_padding_3d(x, pad=(0, 1, 0, 1, 0, 1))
        x = nn.functional.relu(self.deconv_batchnorm2(crop_tensor(self.deconv2(x))), inplace=True)
        x = torch.cat((x, x1), dim=1)
        x = self.layer7(x)
        x = self.deconv4(x)
        return x

class UNet3D(nn.Module):  
    def __init__(self, block, num_layers=2, base_filters=64, blocks_per_layer=2):
        super(UNet3D, self).__init__()
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        
        # Encoder path
        init_channels = 3
        out_channels = base_filters
        self.init_conv = self._make_layer(block, init_channels, out_channels, blocks=blocks_per_layer, stride=1)
        for _ in range(num_layers):
            self.encoders.append(self._make_layer(block, out_channels, out_channels*2, blocks=1, stride=2))
            self.encoders.append(self._make_layer(block, out_channels*2, out_channels*2, blocks=blocks_per_layer, stride=1))
            out_channels *= 2

        # Decoder path
        for _ in range(num_layers):
            self.decoders.append(nn.ConvTranspose3d(out_channels, out_channels//2, kernel_size=3, stride=2, padding=0))
            self.decoders.append(self._make_layer(block, out_channels, out_channels//2, blocks=blocks_per_layer, stride=1))
            out_channels //= 2

        self.final_conv = nn.ConvTranspose3d(out_channels, 3, 1, stride=1, padding=0)

        # Predefine BatchNorm3d and ReLU layers for each decoder step
        self.batch_norms = nn.ModuleList()
        self.relu = nn.ReLU(inplace=True)
        for i in range(num_layers):
            self.batch_norms.insert(0, nn.BatchNorm3d(base_filters * (2 ** i)))  # Adjust channels accordingly

    def _make_layer(self, block, inplanes, outplanes, blocks, stride=1):
        layers = []
        for _ in range(blocks):
            layers.append(block(inplanes, outplanes, stride=stride))
            inplanes = outplanes
        return nn.Sequential(*layers)

    def forward(self, x):
        encoder_outputs = []

        x = self.init_conv(x)
        encoder_outputs.append(x)
        
        # Encoding path
        for i in range(0, len(self.encoders), 2):
            x = self.encoders[i](x)  # Compression layer
            x = self.encoders[i + 1](x)  # Non-compression layer
            encoder_outputs.append(x)

        # Decoding path
        for i in range(0, len(self.decoders), 2):
            x = periodic_padding_3d(x, pad=(0, 1, 0, 1, 0, 1))  # Assuming this is a custom function
            x = self.decoders[i](x)  # Transpose Conv layer
            x = crop_tensor(x)  # Assuming this is a custom function to crop the tensor
            
            # Use the pre-defined BatchNorm3d and ReLU layers
            x = self.batch_norms[i // 2](x)  # BatchNorm
            x = nn.ReLU(inplace=True)(x)  # ReLU
            
            # Skip connection with encoder outputs
            x = torch.cat((x, encoder_outputs[len(encoder_outputs)-2-i//2]), dim=1)  # Skip connection
            
            x = self.decoders[i + 1](x)  # Non-compression layer

        # Final 1x1 Conv
        x = self.final_conv(x)
        return x

In [6]:
import torch
import torch.nn as nn

class UNet3D(nn.Module):
    def __init__(self, block, num_layers=2, base_channels=64):
        super(UNet3D, self).__init__()
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        
        # Encoder path
        init_channels = 3
        out_channels = base_channels
        self.init_conv = self._make_layer(block, init_channels, out_channels, blocks=2, stride=1)
        print(f"Initial conv output channels: {out_channels}")

        for _ in range(num_layers):
            self.encoders.append(self._make_layer(block, out_channels, out_channels*2, blocks=1, stride=2))
            self.encoders.append(self._make_layer(block, out_channels*2, out_channels*2, blocks=2, stride=1))
            out_channels *= 2
        
        # Decoder path
        for _ in range(num_layers):
            self.decoders.append(nn.ConvTranspose3d(out_channels, out_channels//2, kernel_size=3, stride=2, padding=0))
            self.decoders.append(self._make_layer(block, out_channels, out_channels//2, blocks=2, stride=1))
            out_channels //= 2

        self.final_conv = nn.ConvTranspose3d(out_channels, 3, kernel_size=1, stride=1, padding=0)

    def _make_layer(self, block, inplanes, outplanes, blocks, stride=1):
        layers = []
        for _ in range(blocks):
            layers.append(block(inplanes, outplanes, stride=stride))
            inplanes = outplanes
        return nn.Sequential(*layers)

    def forward(self, x):
        encoder_outputs = []

        x = self.init_conv(x)
        print(f"After initial conv: {x.shape}")
        encoder_outputs.append(x)

        # Encoding path
        for i in range(0, len(self.encoders), 2):
            x = self.encoders[i](x)  # Compression layer
            print(f"After encoder compression {i//2}: {x.shape}")
            x = self.encoders[i + 1](x)  # Non-compression layer
            print(f"After encoder non-compression {i//2}: {x.shape}")
            encoder_outputs.append(x)

        # Decoding path
        for i in range(0, len(self.decoders), 2):
            x = periodic_padding_3d(x, pad=(0, 1, 0, 1, 0, 1))
            print(f"After periodic padding before transpose conv {i//2}: {x.shape}")
            x = self.decoders[i](x)  # Transpose Conv layer
            print(f"After transpose conv {i//2}: {x.shape}")
            x = crop_tensor(x)  # Crop to match dimensions if necessary
            print(f"After cropping {i//2}: {x.shape}")
            
            # BatchNorm and ReLU before concatenation
            x = nn.BatchNorm3d(x.shape[1])(x)  # BatchNorm
            x = nn.ReLU(inplace=True)(x)  # ReLU
            print(f"After BatchNorm and ReLU {i//2}: {x.shape}")

            x = torch.cat((x, encoder_outputs[len(encoder_outputs)-2-i//2]), dim=1)  # Skip connection
            print(f"After concatenation {i//2}: {x.shape}")
            
            x = self.decoders[i + 1](x)  # Non-compression layer
            print(f"After decoder non-compression {i//2}: {x.shape}")

        # Final 1x1 Conv
        x = self.final_conv(x)
        print(f"Final output shape: {x.shape}")
        return x


In [11]:
model = UNet3D(block=BasicBlock, num_layers=3, base_channels=64)

TypeError: UNet3D.__init__() got an unexpected keyword argument 'base_channels'

In [12]:
model = UNet3D(block=BasicBlock, num_layers=3, base_filters=64)

In [13]:
input_shape = (1, 3, 32, 32, 32)  # Change the dimensions as needed
input_tensor = torch.rand(input_shape)  # Create a random input tensor

# Run a forward pass through the model
with torch.no_grad():  # Disable gradient calculation for inference
    output = model(input_tensor)

In [14]:
model

UNet3D(
  (encoders): ModuleList(
    (0): Sequential(
      (0): BasicBlock(
        (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2))
        (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
    )
    (1): Sequential(
      (0): BasicBlock(
        (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
    )
    (2): Sequential(
      (0): BasicBlock(
        (conv1): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2))
        (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running

In [None]:
%debug 


> [0;32m/tmp/ipykernel_62867/2134185488.py[0m(120)[0;36mforward[0;34m()[0m
[0;32m    118 [0;31m            [0mx[0m [0;34m=[0m [0mnn[0m[0;34m.[0m[0mReLU[0m[0;34m([0m[0minplace[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m([0m[0mx[0m[0;34m)[0m  [0;31m# ReLU[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    119 [0;31m[0;34m[0m[0m
[0m[0;32m--> 120 [0;31m            [0mx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mcat[0m[0;34m([0m[0;34m([0m[0mx[0m[0;34m,[0m [0mencoder_outputs[0m[0;34m[[0m[0;34m-[0m[0;34m([0m[0mi[0m [0;34m+[0m [0;36m1[0m[0;34m)[0m[0;34m][0m[0;34m)[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m  [0;31m# Skip connection[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    121 [0;31m[0;34m[0m[0m
[0m[0;32m    122 [0;31m            [0mx[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdecoders[0m[0;34m[[0m[0;36m2[0m [0;34m*[0m [0mi[0m [0;34m+[0m [0;36m1[0m[0;34m][0m[0;34m([0m[0mx[0m[0

In [5]:
model = Lpt2NbodyNet(block=BasicBlock)
print(model)

Lpt2NbodyNet(
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv3d(3, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1))
      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1))
      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2))
      (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1))
      (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1