In [2]:
import numpy as np
import pandas as pd
import torch

from models.mstcn import MultiStageModel

# Checkpoint file evaluated by main(); its 'state_dict' entry is loaded into
# MultiStageModel after the 'model.' key prefix is stripped.
# NOTE(review): the file name records val_acc=0.27 at epoch 15 — confirm this
# is the intended checkpoint and not an early/poor one.
ckpt_path = 'logs/250314-131556_TeCNO_Cholec80_mstcn_MultiStageModel/checkpoints/epoch=15-val_acc=0.27.ckpt'


class HParams:
    """Lightweight namespace carrying the MS-TCN hyper-parameters.

    main() passes an instance of this class to ``MultiStageModel``, which is
    presumably expected to read these values as plain attributes (confirm
    against models.mstcn).
    """

    def __init__(self, mstcn_stages, mstcn_layers, mstcn_f_maps, mstcn_f_dim, out_features, mstcn_causal_conv):
        settings = (
            ("mstcn_stages", mstcn_stages),
            ("mstcn_layers", mstcn_layers),
            ("mstcn_f_maps", mstcn_f_maps),
            ("mstcn_f_dim", mstcn_f_dim),
            ("out_features", out_features),
            ("mstcn_causal_conv", mstcn_causal_conv),
        )
        # Attach every setting to the instance in declaration order.
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)


def load_video(path, step=25):
    """Load one exported feature pickle and temporally subsample it.

    Items 0, 1 and 2 of the unpickled object are unpacked as ``stem``,
    ``y_hat`` and ``y`` (presumably CNN features, CNN predictions and
    ground-truth phase labels — confirm against the exporter).

    Args:
        path: Path to the ``.pkl`` file to read.
        step: Keep every ``step``-th row along the first axis. Defaults to
            25, the value that was previously hard-coded (presumably
            25 fps -> 1 fps, going by the ``_25.0fps.pkl`` file names used
            in this notebook — confirm).

    Returns:
        ``(stem, y_hat, y)``: ``stem`` and ``y_hat`` as float32 arrays and
        ``y`` as an array with NumPy's inferred dtype, each subsampled by
        ``step`` along the first axis.

    Raises:
        ValueError: If ``step`` is smaller than 1.
    """
    # A zero step fails inside NumPy slicing and a negative step would silently
    # reverse the sequence, so reject both before touching the file.
    if step < 1:
        raise ValueError(f"step must be a positive integer, got {step!r}")
    # NOTE: read_pickle can execute arbitrary code embedded in the file; only
    # load exports produced by a trusted pipeline.
    unpickled_x = pd.read_pickle(path)
    stem = np.asarray(unpickled_x[0], dtype=np.float32)[::step]
    y_hat = np.asarray(unpickled_x[1], dtype=np.float32)[::step]
    y = np.asarray(unpickled_x[2])[::step]
    return stem, y_hat, y

# Index of the exported video to run inference on; substituted into video_path.
video_index = 55
# Feature pickle for that video, read by load_video() in main().
# NOTE(review): the directory is named 1.0fps while the file name says 25.0fps,
# and load_video() keeps every 25th row — confirm which frame rate the stored
# arrays actually have so the subsampling is not applied twice.
video_path = f'logs/250119-122819_FeatureExtraction_Cholec80FeatureExtract_cnn_OneHeadResNet50Model/cholec80_pickle_export/1.0fps/video_{video_index}_25.0fps.pkl'

def _clean_state_dict(state_dict, prefix='model.', drop_keys=('ce_loss.weight',)):
    """Return a copy of *state_dict* re-keyed for loading into the bare model.

    A leading *prefix* is removed from each key. Only the start of the key is
    stripped; ``str.replace`` would also rewrite later occurrences. Entries
    whose resulting key is in *drop_keys* are omitted. These are presumably
    stored in the checkpoint without a matching model parameter, which would
    break a strict load.
    """
    cleaned = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        if key in drop_keys:
            continue
        cleaned[key] = value
    return cleaned


def main():
    """Run the MS-TCN checkpoint on one video and print labels and outputs.

    Builds ``MultiStageModel`` with fixed hyper-parameters and loads the
    weights from ``ckpt_path`` strictly. It then reads the features for
    ``video_path`` via ``load_video``. It prints the ground-truth labels
    followed by the softmax of the model output.
    """
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    print(f"Using device: {device}")
    # These must match the architecture the checkpoint was trained with,
    # otherwise the strict load below fails on missing/unexpected keys.
    model = MultiStageModel(HParams(
        mstcn_stages=2,
        mstcn_layers=8,
        mstcn_f_maps=32,
        mstcn_f_dim=2048,
        out_features=13,
        mstcn_causal_conv=True,
    ))
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
    model.load_state_dict(_clean_state_dict(checkpoint['state_dict']), strict=True)
    model.to(device)
    model.eval()

    stem, _, y = load_video(video_path)
    stem = torch.tensor(stem).to(device)
    # transpose(2, 1) needs a 3-D tensor. Assumes load_video returns a single
    # (time, features) array — TODO confirm. Add the batch axis when it is
    # missing, then swap the last two axes, presumably giving the
    # (batch, features, time) layout the model's temporal convolutions expect.
    # The transposed tensor must be the one passed to the model.
    if stem.dim() == 2:
        stem = stem.unsqueeze(0)
    stem = stem.transpose(2, 1)
    with torch.no_grad():
        out_stem = model(stem)
    # dim=2 is presumably the class axis of the model output — confirm
    # against models.mstcn.
    phases = torch.softmax(out_stem, dim=2)
    print(y)
    print(phases)
    
# Run only when executed directly (script or notebook cell, where __name__ is
# "__main__"), not as a side effect of importing this code elsewhere.
if __name__ == "__main__":
    main()

Using device: mps
num_stages_classification: 2, num_layers: 8, num_f_maps: 32, dim: 2048


KeyboardInterrupt: 