How to convert RIFE to an ONNX model? #56
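I write a new export script every time the architecture changes, which happens with nearly every RIFE version. I make some adjustments so that I can use it properly with …. The script changes depending on the version, but taking 4.12 as an example:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64):
        ...

    def forward(self, x, flow=None, scale=1):
        # scale may arrive as a 0-d tensor (see scale_list in IFNet.forward);
        # convert it to a Python number, plain numbers pass through unchanged
        try:
            scale = scale.item()
        except AttributeError:
            pass
        ...


class IFNet(nn.Module):
    def __init__(self):
        ...

    def forward(self, input, fastmode=True, ensemble=False):
        # Packed input: channels 0-2 = img0, 3-5 = img1, 6 = timestep, 7 = scale.
        # The clamp also limits the timestep and scale channels to [0, 1].
        input = torch.clamp(input, 0, 1)
        img0 = input[:, :3]
        img1 = input[:, 3:6]
        # timestep and scale are read from a single element (batch 0, pixel 0, 0)
        timestep = input[:, 6:7][0][0][0][0]
        scale = input[:, 7:8][0][0][0][0]
        scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        n, c, h, w = img0.shape
        # Pad height and width up to the next multiple of 64
        ph = ((h - 1) // 64 + 1) * 64
        pw = ((w - 1) // 64 + 1) * 64
        padding = (0, pw - w, 0, ph - h)
        img0 = F.pad(img0, padding)
        img1 = F.pad(img1, padding)
        x = torch.cat((img0, img1), 1)
        # Broadcast the scalar timestep into an (n, 1, ph, pw) plane
        timestep = (x[:, :1].clone() * 0 + 1) * timestep
        timestep = timestep.float()
        ...
        # Crop the padded result back to the original size
        return merged[3][:, :, :h, :w]


def convert(param, rank=-1):
    # Strip the "module." prefix that (Distributed)DataParallel adds to checkpoint keys
    if rank == -1:
        return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}
    else:
        return param


model = IFNet()
model.eval()
state_dict = convert(torch.load("flownet.pkl", map_location="cpu"))
# strict=False skips missing/unexpected keys instead of raising
model.load_state_dict(state_dict, strict=False)
# torch.save(model.state_dict(), "resaved_rife.pth")

with torch.inference_mode():
    # NCHW layout: axis 2 is height, axis 3 is width
    dynamic_axes = {
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size", 2: "height", 3: "width"},
    }
    torch.onnx.export(
        model.cuda(),
        # Dummy input: 6 random image channels plus timestep and scale planes set to 1
        torch.cat([torch.rand(1, 6, 256, 256), torch.ones(1, 2, 256, 256)], 1).cuda(),
        "rife412_fastTrue_ensembleFalse_op18_clamp.onnx",
        verbose=False,
        opset_version=18,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
```

For fp16 you will need further adjustments:

```python
class IFBlock(nn.Module):
    def forward(self, x, flow=None, scale=1):
        ...
        # Cast the block input to fp16 before the first convolution
        feat = self.conv0(x.half())
        ...


class IFNet(nn.Module):
    def forward(self, input, fastmode=True, ensemble=False):
        ...
        # Cast the timestep plane to fp16 as well
        timestep = timestep.half()
        ...
```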
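The padding step in `IFNet.forward` rounds height and width up to the next multiple of 64, and the final `[:, :, :h, :w]` crops the padding away again. The arithmetic can be checked in plain Python; `padded_size` and `pad_tuple` are hypothetical helper names used only for illustration, not part of RIFE. Note that `F.pad` takes padding for the last dimension first, so for NCHW the tuple reads (left, right, top, bottom).

```python
def padded_size(h, w, multiple=64):
    """Round h and w up to the next multiple of `multiple` (same formula as forward())."""
    ph = ((h - 1) // multiple + 1) * multiple
    pw = ((w - 1) // multiple + 1) * multiple
    return ph, pw


def pad_tuple(h, w, multiple=64):
    """Padding tuple in F.pad order for an NCHW tensor: (left, right, top, bottom)."""
    ph, pw = padded_size(h, w, multiple)
    return (0, pw - w, 0, ph - h)


if __name__ == "__main__":
    # 1080p: height needs 8 extra rows, width is already a multiple of 64
    print(padded_size(1080, 1920))  # (1088, 1920)
    print(pad_tuple(1080, 1920))    # (0, 0, 0, 8)
    # The 256x256 dummy export size needs no padding at all
    print(pad_tuple(256, 256))      # (0, 0, 0, 0)
```

Padding only on the right and bottom keeps the top-left corner aligned, which is why a plain slice is enough to undo it.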
Thank you, I roughly understand the ONNX conversion process.
@styler00dollar Hi my friend, sorry to bother you. Could you please share the full modified RIFE code? Another question: I see this dummy input for the ONNX export,
and when I use the ONNX model for inference, what should the inputs be? For example:
Hi everyone, sorry to bother you again. Once I convert RIFE to ONNX, I then have to create the input with NumPy for inference:
Is that correct?
The input shape with the ONNX code I showed is …
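Going only by the export script in the first reply (channels 0-2 = img0, 3-5 = img1, 6 = timestep plane, 7 = scale plane, all clamped to [0, 1]), a NumPy input could be built roughly as below. This is a sketch under assumptions: `make_rife_input` is a hypothetical helper, frames are assumed to be HxWx3 uint8 arrays, and the color channel order (RGB vs BGR) has to match whatever the exported model actually expects.

```python
import numpy as np


def make_rife_input(img0, img1, timestep=0.5, scale=1.0):
    """Pack two HxWx3 uint8 frames into the (1, 8, H, W) float32 layout
    read by the modified IFNet.forward (hypothetical helper)."""
    if img0.shape != img1.shape or img0.ndim != 3 or img0.shape[2] != 3:
        raise ValueError("expected two HxWx3 frames of the same size")
    if not (0.0 < scale <= 1.0):
        # forward() clamps the whole input to [0, 1] and then divides by scale
        raise ValueError("scale must be in (0, 1]")
    h, w, _ = img0.shape

    def to_chw(img):
        # HWC uint8 -> CHW float32 in [0, 1]
        return img.astype(np.float32).transpose(2, 0, 1) / 255.0

    # Constant planes: forward() only reads one element of each
    t_plane = np.full((1, h, w), timestep, dtype=np.float32)
    s_plane = np.full((1, h, w), scale, dtype=np.float32)
    packed = np.concatenate([to_chw(img0), to_chw(img1), t_plane, s_plane], axis=0)
    return packed[np.newaxis]  # add batch dimension -> (1, 8, H, W)


if __name__ == "__main__":
    a = np.zeros((720, 1280, 3), dtype=np.uint8)
    b = np.full((720, 1280, 3), 255, dtype=np.uint8)
    x = make_rife_input(a, b, timestep=0.5)
    print(x.shape, x.dtype)  # (1, 8, 720, 1280) float32
```

With onnxruntime, the array would be passed as `session.run(None, {"input": x})`, where `"input"` is the name given in `input_names` during export. For an fp16 export, it is worth checking the declared input type with `session.get_inputs()[0].type` before choosing the dtype.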
Hi everyone, does anyone have ONNX conversion code that actually works? I've been trying to follow the steps for ages and never got it right.
I found converted RIFE ONNX models at https://github.com/styler00dollar/VSGAN-tensorrt-docker/releases, but their input differs slightly from the input of Practical-RIFE: it seems that the timestep is integrated into the input. Could you share the code for the RIFE ONNX conversion? I didn't find it in the project.
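One detail of the packed input in the modified `forward()` is easy to miss: `timestep` and `scale` are read from a single element (`input[:, 6:7][0][0][0][0]`), i.e. batch item 0 at pixel (0, 0). Per-pixel or per-batch-item values in those planes are therefore ignored. A hypothetical NumPy helper, `unpack_rife_input`, that mirrors the slicing and flags non-uniform planes before inference could look like this (a sketch, not part of any released export script):

```python
import numpy as np


def unpack_rife_input(x):
    """Split a packed (N, 8, H, W) array the way the modified IFNet.forward does
    (hypothetical helper for checking inputs before inference)."""
    if x.ndim != 4 or x.shape[1] != 8:
        raise ValueError(f"expected shape (N, 8, H, W), got {x.shape}")
    x = np.clip(x, 0.0, 1.0)          # forward() clamps the whole tensor first
    img0, img1 = x[:, 0:3], x[:, 3:6]
    timestep = float(x[0, 6, 0, 0])   # same element as input[:, 6:7][0][0][0][0]
    scale = float(x[0, 7, 0, 0])      # same element as input[:, 7:8][0][0][0][0]
    # Every other value in channels 6 and 7 is ignored, so report whether they agree
    uniform = bool(np.all(x[:, 6] == timestep) and np.all(x[:, 7] == scale))
    return img0, img1, timestep, scale, uniform


if __name__ == "__main__":
    x = np.zeros((2, 8, 4, 4), dtype=np.float32)
    x[:, 6] = 0.5
    x[:, 7] = 1.0
    x[1, 6] = 0.25  # second batch item asks for a different timestep
    img0, img1, t, s, ok = unpack_rife_input(x)
    print(t, s, ok)  # 0.5 1.0 False
```

If different timesteps per batch item are needed, running one batch item per call is the safer choice with this input layout.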