Fine-Tuning a U-Net trained using MONAI lib #387

Open
antoniocandito opened this issue Jun 12, 2024 · 0 comments

Description:

Hi,

I would like to prune a model developed using the MONAI library, but I am facing the following error:

Code:

```python
import torch
import torch_pruning as tp
from monai.networks.nets import UNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model architecture
network = UNet(
    spatial_dims=3,
    in_channels=2,
    out_channels=13,
    channels=(32, 64, 128, 256, 512),
    strides=(2, 2, 2, 2),
    norm="batch",
    num_res_units=2,
    dropout=0.2,
).to(device)

# Load the pre-trained model weights (placeholder path) into the network
state_dict = torch.load("pre_trained_model_path.pth", map_location=device)
network.load_state_dict(state_dict)

# PRUNING
# Example input used to trace the model's dependency graph
example_inputs = torch.randn(1, 2, 256, 256, 240).to(device)
# Importance criterion for parameter selection
imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")

# Pruner initialization
print(network)
iterative_steps = 5  # reach the target pruning ratio over several steps
pruner = tp.pruner.MagnitudePruner(
    network,
    example_inputs,
    global_pruning=False,  # if False, a uniform ratio is assigned to every layer
    importance=imp,  # importance criterion for parameter selection
    iterative_steps=iterative_steps,  # number of steps to reach the target ratio
    pruning_ratio=0.4,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(network, example_inputs)

for step in range(iterative_steps):
    # pruner.step() removes the channels with the lowest importance
    pruner.step()
```
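
Since the pruner has to keep channel-wise concatenations consistent, it helps to know where MONAI's `UNet` concatenates and with which widths. The sketch below computes, for `channels=(32, 64, 128, 256, 512)`, the channel counts of the two tensors joined at each skip connection. It assumes MONAI's recursive layout (each level is `down -> SkipConnection(subblock) -> up`, with `SkipConnection` in its default `"cat"` mode joining `[x, subblock(x)]` along dim 1) and that the deepest sub-block is the bottom layer with `channels[-1]` outputs; the helper `unet_concat_widths` is hypothetical and has not been checked against a specific MONAI release.

```python
from typing import List, Sequence, Tuple


def unet_concat_widths(channels: Sequence[int]) -> List[Tuple[int, int]]:
    """Channel widths (skip, deeper) joined by torch.cat at each UNet level,
    listed from the full-resolution level down to the deepest one.

    Assumption: at level k the down block outputs channels[k] and the nested
    sub-network is built to return channels[k] as well, except at the deepest
    level, where the sub-block is the bottom layer with channels[-1] outputs.
    """
    widths = []
    last_level = len(channels) - 2
    for k in range(len(channels) - 1):
        skip = channels[k]  # output of this level's down-sampling block
        deeper = channels[-1] if k == last_level else channels[k]
        widths.append((skip, deeper))
    return widths


widths = unet_concat_widths((32, 64, 128, 256, 512))
print(widths)  # [(32, 32), (64, 64), (128, 128), (256, 512)]
print([a + b for a, b in widths])  # [64, 128, 256, 768]
```

If these assumptions hold, every concatenation in this model has fixed, known input widths, so a `None` among the concatenation sizes would suggest that one of them was not inferred while tracing the model, rather than being genuinely unknown.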

Error:

```
  File "...\lib\site-packages\torch_pruning\dependency.py", line 196, in prune
    dep(idxs)
  File "...lib\site-packages\torch_pruning\dependency.py", line 109, in __call__
    result = self.handler(self.target.module, idxs)
  File "...lib\site-packages\torch_pruning\ops.py", line 111, in prune_out_channels
    offsets.append(offsets[i] + concat_sizes[i])
TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'
```
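
The last frame shows the failing arithmetic: cumulative channel offsets are built from `concat_sizes`, and one entry is `None`, so `int + None` raises the `TypeError`. The stand-alone sketch below reproduces that accumulation with a descriptive error that names the offending input. The function `concat_offsets` is hypothetical and is not torch-pruning's actual code; it only illustrates the failure mode.

```python
from typing import List, Optional


def concat_offsets(concat_sizes: List[Optional[int]]) -> List[int]:
    """Start offset of each concatenated input along the channel axis,
    plus the total width as the final entry.

    Mirrors the traceback line `offsets.append(offsets[i] + concat_sizes[i])`,
    but raises a descriptive ValueError instead of a bare TypeError when
    a size is unknown (None).
    """
    offsets = [0]
    for i, size in enumerate(concat_sizes):
        if size is None:
            raise ValueError(f"channel count of concat input {i} is unknown (None)")
        offsets.append(offsets[i] + size)
    return offsets


print(concat_offsets([256, 512]))  # [0, 256, 768]

try:
    concat_offsets([256, None])
except ValueError as err:
    print(err)  # channel count of concat input 1 is unknown (None)
```

Inside torch-pruning the sizes are presumably filled in while the dependency graph is traced, so this sketch explains why a `None` breaks the offset computation but does not by itself fix how that size was inferred.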
