In [1]:
import logging
from datetime import datetime
import torch
import numpy as np
import pandas as pd
import random
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR

from model import VADModel
from ucf_test import test
# from utils.dataset import UCFDataset
from utils.tools import get_prompt_text, get_batch_label
import ucf_option

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Configuration variables
gt_path = './data/ucf_gt.json'
video_feature_list = 'list/ucf_CLIP_rgbtest.csv'
# UCF ground-truth annotation file (original note: describes the action intervals)
gt_txt = 'list/Temporal_Anomaly_Annotation.txt'

# Summary of class names: dataset label -> lower-camel-case text
label_map = {
    'Normal': 'normal',
    'Abuse': 'abuse',
    'Arrest': 'arrest',
    'Arson': 'arson',
    'Assault': 'assault',
    'Burglary': 'burglary',
    'Explosion': 'explosion',
    'Fighting': 'fighting',
    'RoadAccidents': 'roadAccidents',
    'Robbery': 'robbery',
    'Shooting': 'shooting',
    'Shoplifting': 'shoplifting',
    'Stealing': 'stealing',
    'Vandalism': 'vandalism',
}

batch_size = 64

# Temporal length of each video feature sequence
visual_length = 256

# Training-set list (original note: contains feature paths and labels)
train_list = 'list/ucf_CLIP_rgb_yanai.csv'
# Test-set list (original note: contains feature paths and labels)
test_list = 'list/ucf_CLIP_rgbtest_yanai.csv'




In [4]:
# Load the ground-truth array.
# Original author's note: size is (temporal length of each feature file * 16) summed over
# the files (290 of them) — TODO confirm against the file on disk.
gt = np.load("list/gt_ucf.npy")

# Load gt_segment and gt_label; per the original note, both are used for the mAP computation.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects — only load
# these .npy files from a trusted source.
# Original note: gtsegments holds, per file, the temporal interval(s) where the label
# occurs (e.g. where the anomalous action is); for a normal video it is the whole span.
gtsegments = np.load('list/gt_segment_ucf.npy', allow_pickle=True)

# Original note: gtlabels holds the class name of each file; a normal video is marked "A".
gtlabels = np.load('list/gt_label_ucf.npy', allow_pickle=True)



In [None]:
import torch.utils.data as data

def get_prompt_text(label_map: dict):
    """Collect the class names used as prompt text.

    Args:
        label_map: Mapping whose values are the prompt strings.

    Returns:
        A new list holding the values of ``label_map`` in the mapping's
        iteration order.
    """
    return list(label_map.values())


def process_feat(feat, length, is_random=False):
    """Crop or pad a feature sequence to ``length`` rows along axis 0.

    NOTE(review): ``pad``, ``random_extract`` and ``uniform_extract`` are not
    defined or imported in the visible part of this notebook — confirm they
    are in scope before this function is called.

    Args:
        feat: Array-like exposing ``.shape``; axis 0 is treated as time.
        length: Target number of rows.
        is_random: When the input is longer than ``length``, choose
            ``random_extract`` if truthy, otherwise ``uniform_extract``.

    Returns:
        ``(features, valid_length)``. ``valid_length`` is ``length`` when the
        input was longer than ``length``; otherwise it is the original number
        of rows and the features come from ``pad(feat, length)``.
    """
    num_rows = feat.shape[0]

    # Short (or exactly-sized) inputs are padded and report their true length.
    if num_rows <= length:
        return pad(feat, length), num_rows

    # Longer inputs are reduced to `length` rows by the selected extractor.
    extract = random_extract if is_random else uniform_extract
    return extract(feat, length), length
    

def process_split(feat, length):
    """Split a feature sequence into fixed-size windows (used at test time).

    NOTE(review): ``pad`` is not defined or imported in the visible part of
    this notebook — confirm it is in scope before this function is called.

    Args:
        feat: 2-D array; axis 0 is sliced into windows, axis 1 is kept.
        length: Window size along axis 0.

    Returns:
        ``(features, clip_length)`` where ``clip_length`` is the original
        ``feat.shape[0]``. If the input is shorter than ``length`` the features
        are ``pad(feat, length)`` (no leading window axis). Otherwise they are
        an array of shape ``(int(clip_length / length) + 1, length, feat.shape[1])``;
        the final window is the padded remainder, which is an empty slice when
        ``clip_length`` is an exact multiple of ``length``.
    """
    total_rows = feat.shape[0]
    if total_rows < length:
        return pad(feat, length), total_rows

    feat_dim = feat.shape[1]
    num_windows = int(total_rows / length) + 1
    last = num_windows - 1

    # Every window before the last one is a full slice of `length` rows.
    windows = [
        feat[k * length:k * length + length, :].reshape(1, length, feat_dim)
        for k in range(last)
    ]

    # The last window holds the remainder and is padded up to `length` rows.
    remainder = feat[last * length:last * length + length, :]
    windows.append(pad(remainder, length).reshape(1, length, feat_dim))

    return np.concatenate(windows, axis=0), total_rows
    

class UCFDataset(data.Dataset):
    """Dataset that reads UCF feature files listed in a CSV.

    The CSV is read with ``pd.read_csv`` and must provide a ``path`` column
    (passed to ``np.load``) and a ``label`` column.

    Args:
        clip_dim: Target temporal length handed to ``process_feat`` /
            ``process_split``.
        file_path: Path of the CSV list file.
        test_mode: If true, every row is kept and items are produced with
            ``process_split``; otherwise rows are filtered by ``normal`` and
            items are produced with ``process_feat``.
        label_map: Stored on the instance; not used by this class itself.
        normal: Outside test mode, keep only rows labelled ``'Normal'`` when
            true, and only the other rows when false.
    """

    def __init__(self, clip_dim: int, file_path: str, test_mode: bool, label_map: dict, normal: bool = False):
        self.df = pd.read_csv(file_path)
        self.clip_dim = clip_dim
        self.test_mode = test_mode
        self.label_map = label_map
        self.normal = normal
        if not test_mode:
            is_normal = self.df['label'] == 'Normal'
            keep = is_normal if normal else ~is_normal
            # drop=True: renumber rows 0..n-1 without leaving the old index
            # behind as an extra "index" column.
            self.df = self.df.loc[keep].reset_index(drop=True)

    def __len__(self):
        """Return the number of rows kept from the CSV."""
        return self.df.shape[0]

    def __getitem__(self, index):
        """Load one item.

        Returns:
            ``(clip_feature, clip_label, clip_length)``: the processed features
            as a ``torch`` tensor, the row's ``label`` value, and the length
            reported by ``process_feat`` / ``process_split``.
        """
        # Fetch the row once instead of re-indexing the frame per column.
        row = self.df.loc[index]
        clip_feature = np.load(row['path'])
        if self.test_mode:
            clip_feature, clip_length = process_split(clip_feature, self.clip_dim)
        else:
            clip_feature, clip_length = process_feat(clip_feature, self.clip_dim)

        clip_feature = torch.tensor(clip_feature)
        clip_label = row['label']
        return clip_feature, clip_label, clip_length

In [7]:
# Training set: normal videos only
normal_dataset = UCFDataset(
    clip_dim=visual_length,
    file_path=train_list,
    test_mode=False,
    label_map=label_map,
    normal=True,
)
normal_loader = DataLoader(
    normal_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Training set: anomalous videos only
anomaly_dataset = UCFDataset(
    clip_dim=visual_length,
    file_path=train_list,
    test_mode=False,
    label_map=label_map,
    normal=False,
)
anomaly_loader = DataLoader(
    anomaly_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Test set: one video per batch, in file order
test_dataset = UCFDataset(
    clip_dim=visual_length,
    file_path=test_list,
    test_mode=True,
    label_map=label_map,
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Sizes noted by the original author:
# len(normal_dataset): 8000, len(anomaly_dataset): 8100, len(test_dataset): 290


In [None]:
# Build the model
cls_num = len(label_map)  # number of classes
embed_dim = 512  # feature (embedding) dimension
visual_width = 512  # visual feature dimension at each time step
visual_head = 1  # number of heads for the visual features (not passed to VADModel below)
prompt_prefix = 10  # prompt augmentation size (original note)
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Instantiate the model.
# TODO (original author's note): the model construction itself has some known
# issues that are being ignored for now.
model = VADModel(
    cls_num,
    embed_dim,
    visual_length,
    visual_width,
    prompt_prefix,
    device
)

# Display the module structure as the cell output
model



VADModel(
  (temporal_attn1): MTCAttentionBlock(
    (attn): MultiScaleCompressedAttention(
      (to_qkv): Linear(in_features=512, out_features=1536, bias=False)
      (split_heads): Rearrange('b l (h d) -> b h l d', h=8)
      (merge_heads): Rearrange('b h l d -> b l (h d)')
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (comp_blk1): CompressionBlock(
        (k_comp_mlp): Sequential(
          (0): Linear(in_features=1024, out_features=32, bias=True)
          (1): QuickGELU()
          (2): Linear(in_features=32, out_features=32, bias=True)
        )
        (v_comp_mlp): Sequential(
          (0): Linear(in_features=1024, out_features=32, bias=True)
          (1): QuickGELU()
          (2): Linear(in_features=32, out_features=32, bias=True)
        )
      )
      (comp_blk2): CompressionBlock(
        (k_comp_mlp): Sequential(
          (0): Linear(in_features=256, out_features=32, bias=True)
          (1): QuickGELU()
          (2): Linear(in_features=