Is there any classification code? #58

Open
JamesQFreeman opened this issue Jan 30, 2021 · 6 comments


@JamesQFreeman

No description provided.

@ctlin001

What I did was replace the FC layers with my own classification layers, although the performance was not good. Happy to discuss more if you are interested.

@JasperHG90

JasperHG90 commented Aug 6, 2021

I used this:

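import torch
from torch import nn

# Assumption: resnet34 is the 3D ResNet builder from this repository (models/resnet.py).
from models.resnet import resnet34

class MedicalNet(nn.Module):

  def __init__(self, path_to_weights, device):
    super().__init__()
    self.model = resnet34(sample_input_D=1, sample_input_H=112, sample_input_W=112, num_seg_classes=2)
    # Replace the segmentation head with global pooling so the backbone returns a feature vector.
    self.model.conv_seg = nn.Sequential(
        nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
        nn.Flatten(start_dim=1),
        nn.Dropout(0.1)
    )
    net_dict = self.model.state_dict()
    pretrained_weights = torch.load(path_to_weights, map_location=torch.device(device))
    # Drop the "module." prefix from checkpoint keys and keep only keys the backbone has.
    pretrain_dict = {
        k.replace("module.", ""): v
        for k, v in pretrained_weights['state_dict'].items()
        if k.replace("module.", "") in net_dict
    }
    net_dict.update(pretrain_dict)
    self.model.load_state_dict(net_dict)
    # Single-logit classification head on the 512-dimensional pooled features.
    self.fc = nn.Linear(512, 1)

  def forward(self, x):
    features = self.model(x)
    return self.fc(features)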
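As an aside, the key filtering in `__init__` can be tried in isolation without the real checkpoint. This is only a rough sketch under assumptions: `filter_pretrained` is a hypothetical helper, the two small `nn.Sequential` models are stand-ins for the pretrained network and the modified backbone, and the `module.` prefix is added by hand to imitate a checkpoint saved from a wrapped model. Unlike the dict comprehension above, it strips the prefix only at the start of the key and also checks tensor shapes.

```python
import torch
from torch import nn


def filter_pretrained(pretrained_state, target_state):
    """Hypothetical helper: strip a leading 'module.' and keep matching keys only."""
    cleaned = {}
    for key, value in pretrained_state.items():
        name = key[len("module."):] if key.startswith("module.") else key
        # Keep the tensor only if the target has a parameter with this name and shape.
        if name in target_state and target_state[name].shape == value.shape:
            cleaned[name] = value
    return cleaned


# Stand-in for the pretrained network (backbone layer plus a segmentation-style head).
source = nn.Sequential(nn.Conv3d(1, 4, 3), nn.Conv3d(4, 2, 1))
# Imitate a checkpoint whose keys carry the "module." prefix.
checkpoint = {"module." + k: v for k, v in source.state_dict().items()}

# Stand-in for the modified model: same first layer, head replaced by Flatten.
target = nn.Sequential(nn.Conv3d(1, 4, 3), nn.Flatten())
state = target.state_dict()
loaded = filter_pretrained(checkpoint, state)
state.update(loaded)
target.load_state_dict(state)

print(sorted(loaded))  # ['0.bias', '0.weight']
```

Printing the loaded keys like this is a cheap way to confirm that the pretrained weights were actually picked up, since a prefix mismatch would otherwise silently load nothing.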

Then:

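# Assumption: device was defined earlier; a common choice is the line below.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MedicalNet(path_to_weights="pretrain/resnet_34.pth", device=device)

# Freeze the pretrained backbone and train only the new linear head.
# Parameter names carry the attribute prefix ("model.", "fc."), and the
# replaced conv_seg (pooling, flatten, dropout) has no parameters of its own,
# so matching on "conv_seg" would leave every parameter frozen.
for param_name, param in model.named_parameters():
  param.requires_grad = param_name.startswith("fc.")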
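To check which parameters remain trainable after a freeze loop like this, and to see what a single classification step without segmentation masks could look like, here is a minimal sketch. Everything in it is an assumption for illustration: `TinyClassifier` is a hypothetical stand-in with the same `model` plus `fc` layout (so it runs without the pretrained weights), the data is random, and `BCEWithLogitsLoss` is just one loss that fits a single-logit head.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in with the same layout as the class above: `model` + `fc`."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv3d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
            nn.Flatten(start_dim=1),
        )
        self.fc = nn.Linear(8, 1)

    def forward(self, x):
        return self.fc(self.model(x))


torch.manual_seed(0)
model = TinyClassifier()

# Freeze everything except the linear head; names are prefixed by attribute.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("fc.")

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['fc.weight', 'fc.bias']

# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
criterion = nn.BCEWithLogitsLoss()

# Random batch: 4 single-channel volumes (N, C, D, H, W) with binary labels.
x = torch.randn(4, 1, 8, 16, 16)
y = torch.randint(0, 2, (4, 1)).float()

backbone_before = model.model[0].weight.clone()

# One training step: labels replace the masks used for segmentation.
model.train()
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
```

If the printed list is empty, the name filter did not match anything and no weights will be updated during training.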

@Batush123

Hi @JasperHG90, do you have the code for training classification? In train.py there are some parts that are tied to segmentation, for example the masks.

@JasperHG90

@Batush123 I'm not entirely sure what you're asking for. Are you asking me what my input data & training loop look like?

@wizofe

wizofe commented Mar 22, 2022

Hi @JasperHG90, I am looking at a similar problem and I would be glad if you could share your code including the data/training loop, if possible. Thanks!

@alexcla99

Hello @JasperHG90, same here, would you be able to share your data/training loop? Thank you very much :)
