In [None]:
########### Infra Files ###############

In [None]:
import cv2
import torch
import numpy as np
from pathlib import Path
from torchvision import transforms, models
import torch.nn as nn
from PIL import Image
from ultralytics import YOLO

# ----- CONFIG -----
# SEQ_LEN: number of ROI crops collected before the classifier is run.
# IMG_SIZE: side length (pixels) every crop is resized to.
SEQ_LEN, IMG_SIZE = 25, 224
DROP_P = 0.3        # default dropout probability for CNN_RNN
NUM_CLASSES = 10    # width of the classifier's output layer

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Label names, looked up by the classifier's predicted class index.
CLASS_NAMES = [
    'cover',
    'defense',
    'flick',
    'hook',
    'late_cut',
    'lofted',
    'pull',
    'square_cut',
    'straight',
    'sweep',
]

# ----- Extract ground truth from filename -----
def label_from_filename(path) -> str:
    """Return the class label encoded in a clip's filename.

    The filename stem is assumed to look like ``<label>_<index>``
    (e.g. ``cover_0001``). Only the final ``_<index>`` part is removed,
    so labels that themselves contain an underscore (``late_cut``,
    ``square_cut``) are returned whole instead of being cut at the
    first underscore. A stem without any underscore is returned as-is.

    Args:
        path: Path (str or ``Path``) to the video file.

    Returns:
        The label portion of the filename stem.
    """
    return Path(path).stem.rsplit("_", 1)[0]


video_path = "cover_0001.avi"
gt_label_name = label_from_filename(video_path)  # e.g., 'cover'
gt_label_text = f"GT: {gt_label_name}"

# ----- Model Definitions -----
def get_backbone():
    """Build the ResNet-18 feature extractor.

    The network is created with the ``IMAGENET1K_V1`` weights and its
    final fully connected layer is replaced by ``nn.Identity`` so that a
    forward pass returns the features that layer would have consumed
    rather than class scores.

    Returns:
        The modified ResNet-18 module.
    """
    pretrained = models.ResNet18_Weights.IMAGENET1K_V1
    net = models.resnet18(weights=pretrained)
    net.fc = nn.Identity()
    return net

class CNN_RNN(nn.Module):
    """Sequence classifier: per-frame CNN features -> RNN -> linear head.

    Every frame of a clip is passed through ``backbone``; the resulting
    feature sequence is fed to a single-layer LSTM or GRU with hidden
    size 256, and the RNN output at the last time step is passed through
    dropout and a linear layer producing ``NUM_CLASSES`` scores.

    Args:
        backbone: Module mapping ``(N, C, H, W)`` images to ``(N, feat_dim)``
            feature vectors.
        rnn_type: ``"LSTM"`` or ``"GRU"``.
        bidir: Whether the recurrent layer is bidirectional.
        drop_p: Dropout probability applied before the linear head.
        feat_dim: Size of the per-frame feature vector produced by
            ``backbone``. Defaults to 512, the value this class previously
            hard-coded.

    Raises:
        KeyError: If ``rnn_type`` is neither ``"LSTM"`` nor ``"GRU"``.
    """

    def __init__(self, backbone, rnn_type="LSTM", bidir=False, drop_p=DROP_P,
                 feat_dim=512):
        super().__init__()
        self.backbone = backbone
        hidden = 256
        rnn_cls = {"LSTM": nn.LSTM, "GRU": nn.GRU}[rnn_type]
        self.rnn = rnn_cls(
            input_size=feat_dim,
            hidden_size=hidden,
            batch_first=True,
            bidirectional=bidir
        )
        # A bidirectional RNN concatenates both directions' outputs.
        mult = 2 if bidir else 1
        self.dropout = nn.Dropout(drop_p)
        self.head = nn.Linear(hidden * mult, NUM_CLASSES)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Tensor of shape ``(B, T, C, H, W)``.

        Returns:
            Tensor of shape ``(B, NUM_CLASSES)`` with unnormalised scores.
        """
        B, T, C, H, W = x.shape
        # reshape (rather than view) also accepts non-contiguous input,
        # e.g. a clip tensor that was permuted or sliced by the caller.
        frames = x.reshape(B * T, C, H, W)
        feats = self.backbone(frames)         # (B*T, feat_dim)
        feats = feats.reshape(B, T, -1)       # (B, T, feat_dim)
        out, _ = self.rnn(feats)              # (B, T, hidden*mult)
        x_last = out[:, -1, :]                # (B, hidden*mult)
        x_drop = self.dropout(x_last)
        return self.head(x_drop)              # (B, NUM_CLASSES)

# ----- Load Models -----
# Detector whose first box per frame is used as the ROI.
yolo_model = YOLO("best.pt")

# Sequence classifier: build it on DEVICE, restore the saved weights,
# then switch to eval mode for inference.
cnn_lstm_model = CNN_RNN(get_backbone(), rnn_type="LSTM", bidir=False)
cnn_lstm_model = cnn_lstm_model.to(DEVICE)
state_dict = torch.load("CNN-LSTM_best.pth", map_location=DEVICE)
cnn_lstm_model.load_state_dict(state_dict)
del state_dict  # the weights now live in the model; drop the extra copy
cnn_lstm_model.eval()

# ----- Transform -----
# Preprocessing applied to every ROI crop (a PIL image): resize to
# IMG_SIZE x IMG_SIZE, convert to a tensor, then normalise each channel
# with the mean / std values below.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)

# ----- Prepare Video I/O -----
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Without this check a bad path silently yields an empty output video.
    raise IOError(f"Could not open input video: {video_path}")

# The output mirrors the source clip's frame rate and frame size.
fps = cap.get(cv2.CAP_PROP_FPS)
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

out_vid = cv2.VideoWriter(
    "output_with_preds1.mp4",
    cv2.VideoWriter_fourcc(*'mp4v'),
    fps,
    (w, h)
)
if not out_vid.isOpened():
    # Release the capture before bailing out so the input file is not leaked.
    cap.release()
    raise IOError("Could not open 'output_with_preds1.mp4' for writing")

# ----- Buffers & State -----
roi_seq = []        # preprocessed ROI tensors for the sequence being built
frame_buffer = []   # frames read since the last flush, awaiting annotation
last_pred_label_text = "Pred: ---"  # placeholder until the first prediction


def _annotate_and_write(frames, pred_text):
    """Draw the GT and prediction labels on each frame and write it out.

    Args:
        frames: Iterable of BGR frames; each one is modified in place.
        pred_text: Prediction label drawn under the ground-truth label.
    """
    for f in frames:
        # Ground truth on the first text line (blue in BGR).
        cv2.putText(
            f, gt_label_text, (30, 50),
            cv2.FONT_HERSHEY_SIMPLEX, 1.2,
            (255, 0, 0), 3
        )
        # Prediction on the second text line (green).
        cv2.putText(
            f, pred_text, (30, 100),
            cv2.FONT_HERSHEY_SIMPLEX, 1.2,
            (0, 255, 0), 3
        )
        out_vid.write(f)


try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Store for later annotation
        frame_buffer.append(frame.copy())

        # Run YOLO detection on the current frame
        results = yolo_model(frame, verbose=False)[0]
        bboxes = results.boxes.xyxy.cpu().numpy()

        if len(bboxes) > 0:
            x1, y1, x2, y2 = map(int, bboxes[0])  # first detected ROI
            # A negative start index would wrap around in the slice below.
            x1, y1 = max(x1, 0), max(y1, 0)
            roi = frame[y1:y2, x1:x2]
            # A degenerate box yields an empty crop, which cv2.cvtColor
            # rejects; skip it instead of aborting the whole run.
            if roi.size > 0:
                roi_pil = Image.fromarray(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB))
                roi_seq.append(transform(roi_pil))

        # When we have SEQ_LEN ROIs, run the sequence classifier
        if len(roi_seq) == SEQ_LEN:
            with torch.no_grad():
                input_seq = torch.stack(roi_seq).unsqueeze(0).to(DEVICE)
                logits = cnn_lstm_model(input_seq)
                pred_class = logits.argmax(1).item()
            # Update the last prediction
            last_pred_label_text = f"Pred: {CLASS_NAMES[pred_class]}"

            # Annotate all buffered frames with GT and this prediction
            _annotate_and_write(frame_buffer, last_pred_label_text)

            # Reset buffers for next sequence
            roi_seq.clear()
            frame_buffer.clear()

    # After video ends, write any leftover frames using the last prediction
    _annotate_and_write(frame_buffer, last_pred_label_text)
finally:
    # Cleanup: release both handles even if detection or inference raised.
    cap.release()
    out_vid.release()

# Report the same file name that the video writer was opened with.
print("✅ Output video saved as 'output_with_preds1.mp4'")
