Skip to content

yuyun2000/rkan

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

2 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

使用说明

pc端

  • 模型训练

运行train.py将训练一个模型并保存在cp目录下

  • 模型导出

运行export_onnx.py将导出kan.onnx(运行check_op.py可以查看算子及限制)

板端

将rk目录下的文件放至开发板,注意修改其中的设备类型和设备ip,理论上rknntoolkit2的设备均可

  • 模型转换

运行convert_kan.py,将得到kan.rknn模型

  • 模型推理

运行infer.py,将会得到如下输出,对应图片的‘6’(可以运行show_test.py查看测试图片)

  • 输出
[array([[-23.453125  , -11.1796875 , -19.46875   ,  -8.        ,
        -45.6875    , -17.765625  ,  -0.69189453, -20.5       ,
        -34.9375    , -49.        ]], dtype=float32)]

转换说明

只遇到了一个广播相关的问题,具体的代码修改见kan.py-93行 原代码为:

        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
    

修改后为:

        x = x.unsqueeze(-1)
        tempg = grid.unsqueeze(0).expand(x.size(0),-1,-1)
        xtemp = x.expand(-1,-1,tempg.size(-1)-1)
        t1 = xtemp >= grid[:, :-1]
        t2 = xtemp < grid[:, 1:]
        bases = (t1 & t2).to(x.dtype)

About

kan的rknn2部署

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages