Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

网络如何进行微调和迁移学习 #12

Open
mur909 opened this issue May 17, 2020 · 4 comments
Open

网络如何进行微调和迁移学习 #12

mur909 opened this issue May 17, 2020 · 4 comments

Comments

@mur909
Copy link

mur909 commented May 17, 2020

我在一个数据集上训练得到一个权重,我想在另一个数据集上还用这个权重并进行训练,该怎么做呢

@yatengLG
Copy link
Owner

yatengLG commented May 18, 2020

在初始化模型之后,先导入模型参数即可。
以 Retinanet-Pytorch/Demo_train.py 文件为例,按照以下修改即可

`

# 初始化模型
net = RetainNet(cfg)

# 这里先导入你的已经训练好的模型权重文件
net.load_state_dict(torch.load("XXX"))

# 将模型移动到gpu上,cfg.DEVICE.MAINDEVICE定义了模型所使用的主GPU
net.to(cfg.DEVICE.MAINDEVICE)

# 初始化训练器,训练器参数通过cfg进行配置;也可传入参数进行配置,但不建议
trainer = Trainer(cfg)

# 训练器开始在 数据集上训练模型
trainer(net, train_dataset)

`

@mur909
Copy link
Author

mur909 commented May 18, 2020

已试成功,非常感谢。那如果训练集类别个数与原先训练的权重类别不一样,那如何抑制后面的几层

@yatengLG
Copy link
Owner

由于你类别数都不一样,必须更改模型结构。

你可以这样更改,先初始化模型 -> 导入参数 -> 然后更改模型中的predictor 结构,也就是最后的分类和回归的几层。

这样的做法流程是:
`

# 初始化模型
net = RetainNet(cfg)

# 这里先导入你的已经训练好的模型权重文件
net.load_state_dict(torch.load("XXX"))

# 导入 predictor 结构,并进行初始化。
from Model.struct import predictor  
# 将模型中的predictor替换为新的predictor
net.predictor = predictor(num_anchors=anchor数(如果没有更改anchor,那就是9), num_classes=新类别数)

# 将模型移动到gpu上,cfg.DEVICE.MAINDEVICE定义了模型所使用的主GPU
net.to(cfg.DEVICE.MAINDEVICE)

# 初始化训练器,训练器参数通过cfg进行配置;也可传入参数进行配置,但不建议
trainer = Trainer(cfg)

# 训练器开始在 数据集上训练模型
trainer(net, train_dataset)

`

这样做,模型前面的参数均是你训练好的参数,只有predictor 是随机初始化的,你可以使用这种方法进行迁移。

@mur909
Copy link
Author

mur909 commented May 18, 2020

我原本的思路是只改predictor中分类的最后一个卷积层然后训练,但是现在想想那样可能效果并不好,还是应该像你这样,重新训练head部分。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants