# Export RAFT Model

In [None]:
import torch
from torchvision.models.optical_flow import (
    raft_small,
    Raft_Small_Weights,
)


class RaftDefuck(torch.nn.Module):
    """Single-input wrapper around torchvision's pretrained ``raft_small``.

    The wrapped model is called with two separate image batches; this wrapper
    takes both packed into one tensor so the exported module has a single
    input.
    """

    def __init__(self):
        super().__init__()
        # ``weights`` is a keyword-only parameter of ``raft_small``; passing
        # it positionally only works through torchvision's deprecation shim
        # and emits a warning.
        self.model = raft_small(weights=Raft_Small_Weights.C_T_V2)

    def forward(self, frames: torch.Tensor):
        """Run the wrapped model on the two image batches packed in ``frames``.

        Args:
            frames: Tensor indexed along its leading dimension;
                ``frames[0]`` is passed as the first image batch and
                ``frames[1]`` as the second.

        Returns:
            Whatever the wrapped ``raft_small`` model returns for that pair.
        """
        return self.model(frames[0], frames[1])

In [None]:
import torch
import zipfile
import respiration.utils as utils

# Instantiate the wrapper and compile it to TorchScript before saving.
model = RaftDefuck()
model = torch.jit.script(model)

# Save the model
filename = utils.file_path("assets/raft_defuck.pt")
torch.jit.save(model, filename)

# Re-open the saved file as a zip archive in append mode and add a
# "raft_defuck/version" entry containing the single byte "1".
with zipfile.ZipFile(filename, "a", compression=zipfile.ZIP_DEFLATED) as zipf:
    # Add the version file
    with zipf.open("raft_defuck/version", "w") as f:
        f.write(b"1")
    # No explicit zipf.close() here: leaving the with-block closes the archive.

In [None]:
# Read the assets/test.mp4 file
import cv2
import numpy as np

video = cv2.VideoCapture(utils.file_path("assets/test.mp4"))
if not video.isOpened():
    # Fail fast with a clear message instead of producing an empty array that
    # only breaks later when the first frame is indexed.
    video.release()
    raise OSError("Could not open video file assets/test.mp4")

frames = []
try:
    while True:
        # read() returns (success flag, frame); the flag is False once no
        # further frame can be decoded.
        ret, frame = video.read()
        if not ret:
            break

        # Convert the frame to RGB (OpenCV decodes to BGR)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)
finally:
    # Release the capture handle even if decoding or conversion raises.
    video.release()

# Stack the per-frame arrays into one array with a leading frame dimension.
frames = np.array(frames)

In [None]:
# Show the first frame as a visual sanity check of the decoding step.
# The frames were converted from BGR to RGB when they were read, which is the
# channel order imshow uses for colour images.
import matplotlib.pyplot as plt

plt.imshow(frames[0])

In [None]:
# Number of frames decoded from the video
len(frames)

In [None]:
# Keep only the first six frames; all following cells operate on this subset
# (presumably to keep the forward pass and the binary dump small).
frames = frames[:6]

In [None]:
# Load the exported TorchScript module and prepare the frames for a forward pass
import torch

# NOTE(review): the export cell saves "assets/raft_defuck.pt", but this cell
# loads "assets/raft_small.pt" — confirm which file is intended.
# eval() returns the module itself, so loading and switching to inference
# mode can be chained.
model = torch.jit.load(utils.file_path("assets/raft_small.pt")).eval()

# Turn the NumPy frames into a float32 tensor with the axes reordered by
# permute(0, 3, 1, 2) — assumes the input is (N, H, W, C), giving (N, C, H, W).
frames = torch.from_numpy(frames).permute(0, 3, 1, 2).to(torch.float32)

In [None]:
# Inspect the tensor shape after the permute(0, 3, 1, 2) conversion
frames.shape

In [None]:
# Pair every frame with its successor: index 0 of the new leading dimension
# holds frames[:-1] and index 1 holds frames[1:]. This is the layout that
# RaftDefuck.forward indexes as frames[0] and frames[1].
stack = torch.stack([frames[:-1], frames[1:]], dim=0)
stack.shape

In [None]:
# Flatten the stacked frame pairs into a single 1-D tensor
frames_bin = stack.flatten()

# Report the element type and the total number of elements
print(frames_bin.dtype, frames_bin.shape)

In [None]:
# Write the flattened frame pairs (frames_bin) to a file in raw binary format;
# tofile() dumps the array elements back-to-back without any header.
# NOTE: this rebinds `filename`, which previously held the exported model path.
filename = utils.file_path("assets/frames.bin")
frames_bin.numpy().tofile(filename)

In [None]:
import struct

# Check if the file was written correctly: read the first ten 4-byte values
# back and print each one next to the corresponding in-memory element so they
# can be compared by eye.
with open(filename, "rb") as file:
    for inx in range(10):
        # One value per iteration; 4 bytes matches the "f" format below
        # (assumes frames_bin is float32, as produced by the float conversion).
        data = file.read(4)
        print(frames_bin[inx])
        # "f" decodes a single float in native byte order and returns a 1-tuple
        print(struct.unpack("f", data))

In [None]:
# Run the model on the stacked frame pairs.
# Gradient tracking is disabled because this is inference only.
with torch.no_grad():
    flow = model(stack)

In [None]:
# Input frame tensor shape, for comparison with the flow output shape below
frames.shape

In [None]:
# Convert the model output to a NumPy array to inspect its overall shape.
# NOTE(review): presumably `flow` is a list of per-iteration flow tensors, as
# returned by torchvision's RAFT — confirm for the loaded model.
flow = np.array(flow)
flow.shape