
How to convert rife's onnx model? #56

Closed
jhl13 opened this issue Jan 16, 2024 · 7 comments

jhl13 commented Jan 16, 2024

I found converted rife onnx models at https://github.com/styler00dollar/VSGAN-tensorrt-docker/releases, but their input is slightly different from the input of Practical-RIFE. It seems that the timestep is integrated into the input tensor. Can you share the code for the rife onnx model conversion? I didn't find it in the project.

styler00dollar (Owner) commented Jan 16, 2024

I write a new export script every time the architecture changes, which is nearly every rife version. I make some adjustments so I can use it properly with core.trt from mlrt, since that only allows one input. The first 6 channels are the two input images, then there is one channel for the timestep and one channel for the scale. You can see it in rife_trt.py.

The file will change depending on the version, but with 4.12 as an example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64):
        ...
    def forward(self, x, flow=None, scale=1):
        # scale may arrive as a 0-dim tensor sliced out of the packed input;
        # .item() turns it into a plain Python number, and the bare except
        # keeps things working when it is already a number
        try:
            scale = scale.item()
        except:
            pass
        ...

class IFNet(nn.Module):
    def __init__(self):
        ...

    def forward(self, input, fastmode=True, ensemble=False):
        # single packed input: channels 0-2 = img0, 3-5 = img1,
        # channel 6 = timestep plane, channel 7 = scale plane
        input = torch.clamp(input, 0, 1)
        img0 = input[:, :3]
        img1 = input[:, 3:6]
        # the timestep/scale planes are constant, so a single scalar is read out
        timestep = input[:, 6:7][0][0][0][0]
        scale = input[:, 7:8][0][0][0][0]
        scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        n, c, h, w = img0.shape
        # pad both frames so height and width are multiples of 64
        ph = ((h - 1) // 64 + 1) * 64
        pw = ((w - 1) // 64 + 1) * 64
        padding = (0, pw - w, 0, ph - h)
        img0 = F.pad(img0, padding)
        img1 = F.pad(img1, padding)
        x = torch.cat((img0, img1), 1)
        # rebuild a full-resolution timestep map matching the padded frames
        timestep = (x[:, :1].clone() * 0 + 1) * timestep
        timestep = timestep.float()
        ...
        # crop the padding back off before returning the interpolated frame
        return merged[3][:, :, :h, :w]

def convert(param, rank=-1):
    # strip the "module." prefix that DataParallel adds to checkpoint keys
    if rank == -1:
        return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}
    else:
        return param

model = IFNet()
model.eval()
state_dict = convert(torch.load("flownet.pkl", map_location="cpu"))

model.load_state_dict(state_dict, strict=False)
# torch.save(model.state_dict(), "resaved_rife.pth")

with torch.inference_mode():
    # batch, height and width are dynamic (dims 2/3 are height/width in NCHW)
    dynamic_axes = {
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size", 2: "height", 3: "width"},
    }

    torch.onnx.export(
        model.cuda(),
        # dummy input: 6 random image channels plus a timestep plane and a
        # scale plane filled with ones
        torch.cat([torch.rand(1, 6, 256, 256), torch.ones(1, 2, 256, 256)], 1).cuda(),
        "rife412_fastTrue_ensembleFalse_op18_clamp.onnx",
        verbose=False,
        opset_version=18,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
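
If you want to sanity-check the exported file, here is a minimal sketch with onnxruntime (an assumption, not part of the original export script; the file name matches the export call above):

# minimal sanity check of the exported model (assumes onnxruntime is installed)
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "rife412_fastTrue_ensembleFalse_op18_clamp.onnx",
    providers=["CPUExecutionProvider"],
)

# 8 channels in NCHW: img0 (3) + img1 (3) + timestep plane (1) + scale plane (1)
img0 = np.random.rand(1, 3, 256, 256).astype(np.float32)
img1 = np.random.rand(1, 3, 256, 256).astype(np.float32)
timestep = np.full((1, 1, 256, 256), 0.5, dtype=np.float32)  # interpolate at t = 0.5
scale = np.ones((1, 1, 256, 256), dtype=np.float32)          # scale = 1
packed = np.concatenate([img0, img1, timestep, scale], axis=1)

out = sess.run(["output"], {"input": packed})[0]
print(out.shape)  # expected: (1, 3, 256, 256)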

For fp16 you will need further adjustments:

class IFBlock(nn.Module):
    def forward(self, x, flow=None, scale=1):
        ...
        # cast the block input to half before the first convolution
        feat = self.conv0(x.half())
        ...

class IFNet(nn.Module):
    def forward(self, input, fastmode=True, ensemble=False):
        ...
        # keep the timestep map in fp16 instead of fp32
        timestep = timestep.half()
        ...
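
The fp16 export call itself isn't shown here; a possible sketch, assuming the model and the dummy input are simply cast to half precision on CUDA (the fp16 output file name is hypothetical):

# hedged sketch of an fp16 export: cast weights and dummy input to half once
# the forward() adjustments above are in place
model = IFNet()
model.eval()
model.load_state_dict(state_dict, strict=False)
model = model.cuda().half()

dummy = torch.cat(
    [torch.rand(1, 6, 256, 256), torch.ones(1, 2, 256, 256)], 1
).cuda().half()

with torch.inference_mode():
    torch.onnx.export(
        model,
        dummy,
        "rife412_fastTrue_ensembleFalse_op18_clamp_fp16.onnx",
        opset_version=18,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )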

jhl13 (Author) commented Jan 17, 2024

Thank you, I roughly understand the onnx conversion process.

Djdefrag commented Mar 3, 2024

@styler00dollar Hi my friend,

Sorry to bother you, but can you share the full modified RIFE code, please?

And another question: I see this dummy input for the onnx export

torch.cat([torch.rand(1, 6, 256, 256), torch.ones(1, 2, 256, 256)], 1).cuda()

and when I have to use the onnx model for inference, what are the inputs? For example:

  • torch.rand(1, 6, 256, 256) -> should this be the 2 images combined with torch.cat?
  • and torch.ones(1, 2, 256, 256) -> what is this? Is it a fixed value?

styler00dollar (Owner) commented

#56 (comment)

Djdefrag commented
Hi everyone, sorry to bother again.

Once I convert RIFE to onnx, I then have to build the input with numpy for inference:

def concatenate_frames(
        frame_1: numpy_ndarray,
        frame_2: numpy_ndarray,
    ) -> numpy_ndarray:

    height, width = get_image_resolution(frame_1)

    # stack the two frames plus timestep and scale planes along the channel axis (HWC)
    input_images = numpy_concatenate((frame_1, frame_2), axis=2)
    timestep = numpy_ones((height, width, 1))
    scale    = numpy_ones((height, width, 1))
    result = numpy_concatenate((input_images, timestep, scale), axis=2)
    return result

Is that correct?

styler00dollar (Owner) commented Apr 3, 2024

With the onnx code I showed, the input shape is (1, 8, height, width), where the batch size of 1 can vary.
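
In other words the layout is NCHW, not HWC. A corrected packing sketch (an assumption based on that shape; the helper name is carried over from the snippet above, and frames are assumed to be HWC float32 in [0, 1]):

import numpy as np

def concatenate_frames(frame_1: np.ndarray, frame_2: np.ndarray,
                       timestep: float = 0.5, scale: float = 1.0) -> np.ndarray:
    # move channels first, then stack img0, img1, a timestep plane and a
    # scale plane along the channel axis -> (8, H, W)
    height, width = frame_1.shape[:2]
    img0 = frame_1.transpose(2, 0, 1)
    img1 = frame_2.transpose(2, 0, 1)
    t_plane = np.full((1, height, width), timestep, dtype=np.float32)
    s_plane = np.full((1, height, width), scale, dtype=np.float32)
    packed = np.concatenate([img0, img1, t_plane, s_plane], axis=0)
    return packed[np.newaxis].astype(np.float32)  # add batch dim -> (1, 8, H, W)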

lgthappy commented
Hi everyone, do you have working source code for the onnx conversion? I've been trying to follow the steps for ages and never got it right.
