In [1]:
from torchvision.io.video import read_video
from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights

vid, _, _ = read_video("dataset/val/wrestling/5sr0Wgmn7BU_000085_000095.mp4", output_format="TCHW")
vid = vid[:16]  # optionally shorten duration

# Step 1: Initialize model with the best available weights
weights = MViT_V2_S_Weights.DEFAULT
model = mvit_v2_s(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(vid).unsqueeze(0)

Downloading: "https://download.pytorch.org/models/mvit_v2_s-ae3be167.pth" to /home/yungshun/.cache/torch/hub/checkpoints/mvit_v2_s-ae3be167.pth
100%|████████████████████████████████████████| 132M/132M [00:08<00:00, 16.6MB/s]


In [2]:
batch.shape

torch.Size([1, 3, 16, 224, 224])

In [3]:
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
prediction

tensor([1.8366e-04, 1.6695e-04, 1.0879e-04, 6.5832e-05, 1.6735e-04, 3.3791e-04,
        1.4666e-04, 1.5181e-04, 1.3985e-04, 1.3458e-04, 9.4528e-05, 2.1959e-04,
        1.1031e-04, 1.4101e-04, 1.9015e-04, 1.0865e-04, 1.5979e-04, 1.2496e-04,
        2.0478e-04, 1.6411e-04, 1.5531e-04, 1.4839e-04, 1.6432e-04, 1.2638e-04,
        1.9523e-04, 1.5940e-04, 1.1271e-04, 1.2722e-04, 2.7018e-04, 2.1047e-04,
        2.6080e-04, 1.5833e-04, 9.6968e-05, 1.2355e-04, 2.6626e-04, 1.0022e-04,
        1.8172e-04, 1.4483e-04, 1.3130e-04, 8.0548e-05, 2.3233e-04, 2.2879e-04,
        1.8449e-04, 3.3502e-04, 2.2491e-04, 1.0330e-04, 1.0267e-04, 1.5010e-04,
        1.9735e-04, 3.2568e-04, 2.8887e-04, 1.7221e-04, 1.2886e-04, 1.5380e-04,
        1.1498e-04, 6.4975e-04, 1.3146e-04, 1.0020e-04, 1.4592e-04, 2.0867e-04,
        1.9240e-04, 1.1163e-04, 1.3553e-04, 2.1101e-04, 1.4029e-04, 1.0014e-04,
        1.2247e-04, 1.1590e-04, 1.5785e-04, 2.0098e-04, 1.5987e-04, 1.6154e-04,
        8.5722e-05, 1.1510e-04, 1.1226e-

In [4]:
label = prediction.argmax().item()
label

395

In [5]:
score = prediction[label].item()
score

0.9139450192451477

In [6]:
category_name = weights.meta["categories"][label]
category_name

'wrestling'

In [7]:
print(f"{category_name}: {100 * score}%")

wrestling: 91.39450192451477%


In [10]:
print(weights.meta["categories"])

['abseiling', 'air drumming', 'answering questions', 'applauding', 'applying cream', 'archery', 'arm wrestling', 'arranging flowers', 'assembling computer', 'auctioning', 'baby waking up', 'baking cookies', 'balloon blowing', 'bandaging', 'barbequing', 'bartending', 'beatboxing', 'bee keeping', 'belly dancing', 'bench pressing', 'bending back', 'bending metal', 'biking through snow', 'blasting sand', 'blowing glass', 'blowing leaves', 'blowing nose', 'blowing out candles', 'bobsledding', 'bookbinding', 'bouncing on trampoline', 'bowling', 'braiding hair', 'breading or breadcrumbing', 'breakdancing', 'brush painting', 'brushing hair', 'brushing teeth', 'building cabinet', 'building shed', 'bungee jumping', 'busking', 'canoeing or kayaking', 'capoeira', 'carrying baby', 'cartwheeling', 'carving pumpkin', 'catching fish', 'catching or throwing baseball', 'catching or throwing frisbee', 'catching or throwing softball', 'celebrating', 'changing oil', 'changing wheel', 'checking tires', 'che

In [12]:
violence = [
    "archery", 
    "arm wrestling", 
    "bending metal", 
    "blasting sand", 
    "capoeira", 
    "catching or throwing baseball", 
    "catching or throwing frisbee", 
    "catching or throwing softball",
    "chopping wood",
    "cracking neck",
    "cutting pineapple", 
    "cutting watermelon",
    "drop kicking",
    "faceplanting",
    "hammer throw",
    "headbutting",
    "high kick",
    "hitting baseball",
    "javelin throw",
    "kicking field goal",
    "kicking soccer ball",
    "pumping fist",
    "punching bag",
    "punching person (boxing)",
    "slapping",
    "sword fighting",
    "tai chi",
    "throwing axe",
    "throwing ball",
    "wrestling"
]

In [14]:
violence_idx = [i for i, x in enumerate(weights.meta["categories"]) if x in violence]
violence_idx

[5,
 6,
 21,
 23,
 43,
 48,
 49,
 50,
 56,
 76,
 82,
 83,
 105,
 122,
 148,
 150,
 152,
 153,
 166,
 174,
 175,
 256,
 258,
 259,
 314,
 345,
 346,
 356,
 357,
 395]

In [9]:
weights.meta["categories"][395]

'wrestling'

In [16]:
def violence_recognition(violence_idx, prediction):
    label = prediction.argmax().item()
    score = prediction[label].item()
    if label in violence_idx:
        print(f"Violence: {100 * score}%")
    else:
        print(f"Normal: {100 * score}%")
        
violence_recognition(violence_idx, prediction)

Violence: 91.39450192451477%
