### 数据导入


In [None]:
!nvidia-smi
!pip list
!uv pip install mamba-ssm --no-build-isolation -q

In [None]:
# Download the raw archive, the label CSV and the aligned-feature pickle with gdown
import gdown

# (Google Drive file id, local output name), listed in download order
drive_files = [
    ('1SwOr5V-eUaES_aWfdO6C1ykAuja08ipf', 'Raw.zip'),
    ('1cPwMp98Wwew-YIaFjD7lWUvg9-V9f_WK', 'MOSI-label.csv'),
    ('1MFzHi-g3wNzQ3dbDMr_0wtAkrcPIaVvI', 'aligned_50.pkl'),
]

for file_id, output_name in drive_files:
    gdown.download(f'https://drive.google.com/uc?id={file_id}', output_name, quiet=False)

In [None]:
!cp /content/柒肆零/*.npy ./

In [None]:
!unzip /content/Raw.zip

In [None]:
import pickle
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import IPython.display as ipd
import matplotlib.pyplot as plt
import librosa.display

import seaborn as sns

In [None]:
# Absolute paths of the aligned-feature pickle and the label CSV
# (assumes the downloads above were written to /content -- TODO confirm the working directory)
pickle_filename = '/content/aligned_50.pkl'
csv_filename = '/content/MOSI-label.csv'
# Number of target classes; used as the output width of bigModel's last linear layer
n_class = 3

### 提取图像文本语音特征

In [None]:
import torch
from transformers import AutoTokenizer, CLIPTextModel

def extract_clip_token_embeddings(model, tokenizer, texts, max_length=20, device="cpu"):
    """Return the per-token hidden states of the FIRST input text as a NumPy array.

    Args:
        model: text encoder whose output exposes ``last_hidden_state``.
        tokenizer: tokenizer callable matching ``model``.
        texts: a single string or a list of strings. A string is wrapped in a
            one-element list before tokenization.
        max_length: every text is padded / truncated to this many tokens.
        device: device the inputs and the model are moved to.

    Returns:
        ``last_hidden_state[0]`` moved to the CPU and converted to NumPy.
        Only the first text of the batch is returned, even when a list is passed.
        (Presumably shaped (max_length, hidden_size) -- TODO confirm.)
    """
    batch = [texts] if isinstance(texts, str) else texts

    # Tokenize to a fixed length so every sample yields the same number of tokens
    encoded = tokenizer(
        batch,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    ).to(device)

    model = model.to(device)
    model.eval()
    # Inference only: no autograd bookkeeping needed
    with torch.no_grad():
        token_states = model(**encoded).last_hidden_state

    first_item = token_states[0]
    return first_item.cpu().numpy()
from transformers import AutoTokenizer, CLIPTextModel

# Select the compute device here: this cell already needs `device`, and the only other
# visible assignment lives in a much later cell, so a top-to-bottom run would otherwise
# fail with NameError. Re-assigning it later with the same expression is harmless.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the CLIP text encoder and its tokenizer
t_clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
t_clip_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Smoke test on a single sentence
texts = "a photo of a cat"
hidden_states = extract_clip_token_embeddings(t_clip_model, t_clip_tokenizer, texts, device=device)

print("Token embeddings shape:", hidden_states.shape)


In [None]:
from PIL import Image
import requests
from transformers import AutoProcessor, CLIPVisionModelWithProjection

# Load the CLIP vision model (with projection head) and the matching processor
img_clip_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
img_clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Fetch a sample image over HTTP for a quick smoke test
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = img_clip_processor(images=image, return_tensors="pt")

# NOTE(review): this demo call is not wrapped in torch.no_grad() and the model is not
# moved to any device or put in eval mode here.
outputs = img_clip_model(**inputs)
image_embeds = outputs.image_embeds

In [None]:
import cv2
import numpy as np
from PIL import Image
import torch
from transformers import CLIPVisionModelWithProjection, AutoProcessor

def extract_clip_features_from_video(model, processor, video_path, num_frames=5, device="cpu"):
    """Sample evenly spaced frames from a video and embed each one with a CLIP vision model.

    Args:
        model: vision model whose output exposes ``image_embeds``.
        processor: image processor matching ``model``.
        video_path: path of the video file, opened with ``cv2.VideoCapture``.
        num_frames: number of frame indices sampled uniformly over the whole clip.
        device: device the model and the processed frames are moved to.

    Returns:
        ``np.ndarray`` with one row per successfully decoded frame. Frames that
        cannot be read are skipped, so the result can have FEWER than
        ``num_frames`` rows.

    Raises:
        ValueError: if the video reports no frames, or no sampled frame could be decoded.
    """
    model.to(device)
    model.eval()

    cap = cv2.VideoCapture(video_path)
    embeddings = []
    # try/finally guarantees the capture handle is released even when decoding,
    # preprocessing or the model forward pass raises.
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # <= 0 (not just == 0): a non-positive count would otherwise produce negative frame indices
        if frame_count <= 0:
            raise ValueError("无法读取视频或视频为空")

        frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)

        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ret, frame = cap.read()
            if not ret:
                print(f"警告：无法读取第 {idx} 帧，跳过")
                continue

            # OpenCV decodes to BGR; convert to RGB before handing the frame to PIL
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(frame_rgb)

            # Preprocess the frame and extract its embedding
            inputs = processor(images=pil_image, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            embeddings.append(outputs.image_embeds.squeeze(0).cpu().numpy())
    finally:
        cap.release()

    if not embeddings:
        raise ValueError("未成功提取任何帧的特征")

    return np.stack(embeddings)


# Vision encoder used for the video frames.
# NOTE(review): this loads the patch16 checkpoint while the text encoder above uses patch32;
# confirm that mixing the two CLIP variants is intended.
clip_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch16")
v_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch16")
video_file ='/content/Video/Segmented/03bSnISJMiM_10.mp4'
# NOTE(review): `device` must already be defined when this cell runs -- confirm cell order.
features = extract_clip_features_from_video(clip_model, v_processor, video_file, num_frames=5, device=device)

print("提取特征 shape:", features.shape)


### 提取音频特征

In [2]:
import librosa
import numpy as np

def extract_fixed_mfcc(wav_path, sr=16000, duration=1.0, n_mfcc=13, n_fft=400, hop_length=160,
                       max_frames=100):
    """Compute MFCCs over a fixed-length window at the start of an audio file.

    The waveform is loaded at ``sr`` Hz, then zero-padded at the end or
    truncated so that it holds exactly ``int(sr * duration)`` samples.

    Args:
        wav_path: path of the audio file, read with ``librosa.load``.
        sr: target sampling rate in Hz.
        duration: length of the analysed window in seconds.
        n_mfcc: number of MFCC coefficients per frame.
        n_fft: FFT window size in samples.
        hop_length: hop between frames in samples.
        max_frames: upper bound on the number of returned frames. This used to be a
            hard-coded 100 that only matched the default ``sr`` / ``duration`` /
            ``hop_length``; the default keeps the previous behaviour.

    Returns:
        ``np.ndarray`` of shape ``(n_mfcc, T)`` with ``T <= max_frames``.
    """
    target_length = int(sr * duration)

    # Load the audio, resampled to `sr`
    y, _ = librosa.load(wav_path, sr=sr)

    # Force a fixed number of samples: zero-pad short clips, cut long ones
    if len(y) < target_length:
        y = np.pad(y, (0, target_length - len(y)))
    else:
        y = y[:target_length]

    # Extract the MFCC features
    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)

    # Cap the frame count so every sample yields the same shape
    return mfcc[:, :max_frames]


## 合并处理文本 图像 音频

In [None]:
 import tqdm

 video_dir = '/content/Video/Segmented/'
 audio_dir = '/content/Audio/WAV_16000/Segmented/'
 vision_feats = []
 audio_feats = []
 text_feats = []
 # Extract vision / audio / text features for every row of pd_data.
 # (pd_data is not defined in the visible cells; it is assumed to be a DataFrame with
 # 'video_file', 'audio_file' and 'text' columns -- TODO confirm.)
 # The loop used to end with a mis-indented `break`. That is a syntax error, and even when
 # re-indented it would stop after the first row and save one-sample feature files that no
 # longer line up with the labels, so it has been removed.
 for idx, row in tqdm.tqdm(pd_data.iterrows(), total=len(pd_data)):
     video_file = video_dir + row['video_file']
     audio_file = audio_dir + row['audio_file']
     one_text = row['text']
     one_vision_feat = extract_clip_features_from_video(clip_model, v_processor, video_file, 5, device)
     one_audio_feat = extract_fixed_mfcc(audio_file)
     one_text_feat = extract_clip_token_embeddings(t_clip_model, t_clip_tokenizer, one_text, device=device)
     vision_feats.append(one_vision_feat)
     audio_feats.append(one_audio_feat)
     text_feats.append(one_text_feat)
 # Convert to numpy arrays.
 # NOTE(review): np.stack needs every per-sample array to have the same shape; a video with
 # skipped frames returns fewer rows and would make the vision stack fail.
 vision_feats = np.stack(vision_feats)
 audio_feats = np.stack(audio_feats)
 text_feats = np.stack(text_feats)

 print("视觉特征 shape:", vision_feats.shape)
 print("音频特征 shape:", audio_feats.shape)
 print("文本特征 shape:", text_feats.shape)

 # Persist the features so later runs can skip the extraction
 np.save("vision_feats.npy", vision_feats)
 np.save("audio_feats.npy", audio_feats)
 np.save("text_feats.npy", text_feats)

 # Load the saved .npy feature files from the current working directory into variables
 vision_feats = np.load("vision_feats.npy")
 audio_feats = np.load("audio_feats.npy")
 text_feats = np.load("text_feats.npy")
 # Print the shapes as a sanity check
 print("视觉特征 shape:", vision_feats.shape)
 print("音频特征 shape:", audio_feats.shape)
 print("文本特征 shape:", text_feats.shape)

### 数据处理

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset over precomputed vision, audio and text features plus labels.

    Each item is ``(image, text, audio, label)``: the three feature entries are
    converted to float tensors and the label to a long tensor.
    """

    def __init__(self, vision_feats, audio_feats, text_feats, pd_data):
        """Keep references to the feature containers and pull the labels out of ``pd_data``.

        Args:
            vision_feats: indexable container of per-sample vision features.
            audio_feats: indexable container of per-sample audio features.
            text_feats: indexable container of per-sample text features.
            pd_data: table whose ``'label_num'`` column holds the integer class labels.
        """
        self.vision_data, self.audio_data, self.text_data = vision_feats, audio_feats, text_feats
        self.labels = pd_data['label_num'].values

    def __len__(self):
        # The vision features alone define the dataset length
        return len(self.vision_data)

    def __getitem__(self, idx):
        # Same conversion for all three modalities, in the order image -> text -> audio
        img, text, audio = (
            torch.tensor(store[idx]).float()
            for store in (self.vision_data, self.text_data, self.audio_data)
        )
        label = torch.tensor(self.labels[idx]).long()
        return img, text, audio, label


# Wrap all samples in one dataset, then split it randomly 80% / 20% into train and test.
# The split is unseeded here, so it differs between runs.
all_ds = CustomDataset(vision_feats,audio_feats,text_feats,pd_data)
train_ds, test_ds = torch.utils.data.random_split(all_ds, [0.8, 0.2])
# Notebook display: shapes of one training sample (image, text, audio) and the split sizes
train_ds[1][0].shape,train_ds[1][1].shape,train_ds[1][2].shape,len(train_ds),len(test_ds)

In [None]:
# Split the hold-out set again: 30% stays as the final test split, 70% goes to training
test_ds_1, test_ds_2 = torch.utils.data.random_split(test_ds, [0.3, 0.7])
len(test_ds_1),len(test_ds_2)

# Merge the extra 70% into the training set.
# NOTE: re-running this cell re-splits test_ds and concatenates onto train_ds again.
train_ds = torch.utils.data.ConcatDataset([train_ds, test_ds_2])
len(train_ds)

train_dl = DataLoader(train_ds, batch_size=8, shuffle=True)
# Evaluate only on test_ds_1. The loader used to be built from test_ds, but 70% of test_ds
# (test_ds_2) is now part of train_ds, so test metrics were computed mostly on training samples.
test_dl = DataLoader(test_ds_1, batch_size=8)

### 模型创新

In [None]:

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from einops import rearrange, repeat

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


class Mamba(nn.Module):
    """Selective state-space (Mamba) sequence-mixing layer.

    Maps a ``(batch, seqlen, d_model)`` tensor to a tensor of the same shape.
    The layer projects the input to two ``d_inner``-wide streams (``x`` and the
    gate ``z``), runs a depthwise 1-D convolution and a selective scan on ``x``,
    and projects the gated result back to ``d_model``.

    The heavy lifting is delegated to kernels from ``mamba_ssm`` and, when it is
    installed, ``causal_conv1d``. ``step`` implements single-token decoding on top
    of cached convolution / SSM states.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        """Build the projections, the depthwise convolution and the SSM parameters.

        Args:
            d_model: width of the layer's input and output.
            d_state: state size per inner channel (second dimension of ``A_log``).
            d_conv: kernel size of the depthwise convolution.
            expand: inner width multiplier, ``d_inner = int(expand * d_model)``.
            dt_rank: rank of the low-rank ``dt`` projection; ``"auto"`` means
                ``ceil(d_model / 16)``.
            dt_min, dt_max: range from which the initial ``dt`` values are drawn
                (log-uniformly).
            dt_init: ``"constant"`` or ``"random"`` initialisation of ``dt_proj.weight``.
            dt_scale: multiplier on the ``dt_proj.weight`` initialisation scale.
            dt_init_floor: lower clamp applied to the initial ``dt`` values.
            conv_bias: whether the convolution has a bias.
            bias: whether ``in_proj`` / ``out_proj`` have a bias.
            use_fast_path: allow the fused ``mamba_inner_fn`` path in ``forward``.
            layer_idx: key under which this layer stores its inference cache;
                required when ``inference_params`` is used.
            device, dtype: forwarded to the parameter constructors.

        Raises:
            NotImplementedError: if ``dt_init`` is neither ``"constant"`` nor ``"random"``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx

        # One projection produces both streams (x and the gate z), hence the factor 2
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

        # Depthwise convolution (groups == channels). It pads d_conv - 1 on both sides;
        # forward() keeps only the first `seqlen` outputs.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        # x_proj output is split into dt (dt_rank), B (d_state) and C (d_state)
        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialise the dt projection weight with scale dt_rank**-0.5 * dt_scale
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Draw dt log-uniformly from [dt_min, dt_max] and clamp it from below
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus, so that softplus(dt_proj.bias) equals the sampled dt
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Marker attribute; presumably read by external re-initialisation code (not used in this file)
        self.dt_proj.bias._no_reinit = True

        # A_log starts as log(1..d_state), repeated for every inner channel;
        # forward() and step() use A = -exp(A_log)
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)
        self.A_log = nn.Parameter(A_log)
        # Marker attribute; presumably read by optimizer setup code (not used in this file)
        self.A_log._no_weight_decay = True
        # D scales the direct x -> y skip term (see step())
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, hidden_states, inference_params=None):
        """Run the layer over a whole sequence.

        Args:
            hidden_states: tensor of shape ``(batch, seqlen, dim)``.
            inference_params: optional decoding cache holder. When given, the
                cached states are updated; when its ``seqlen_offset`` is > 0 the
                call is delegated to ``step`` (single-token decoding).

        Returns:
            Output of ``out_proj``, laid out like ``hidden_states``.
        """

        batch, seqlen, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # Already past the prompt: decode one token using the cached states
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out
        # Apply in_proj by hand so the result comes out directly in (b, d, l) layout
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())
        # The fused path is only taken when no inference cache has to be filled
        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,
                None,
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            x, z = xz.chunk(2, dim=1)
            if conv_state is not None:
                # Store the last d_conv inputs: F.pad left-pads with zeros when
                # seqlen < d_conv and crops from the left (negative pad) otherwise
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))
            if causal_conv1d_fn is None:
                # Fallback: plain Conv1d, dropping the outputs past `seqlen`
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            # Only the weight is applied here; the dt bias is passed to the scan as delta_bias
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                # With return_last_state=True the scan returns (output, last_state)
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        """Decode a single token, updating ``conv_state`` and ``ssm_state`` in place.

        Args:
            hidden_states: tensor whose second dimension must be 1.
            conv_state: rolling buffer of the last ``d_conv`` convolution inputs.
            ssm_state: recurrent SSM state.

        Returns:
            ``(out, conv_state, ssm_state)`` where ``out`` has a length-1 dimension
            re-inserted at position 1.
        """
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))
        x, z = xz.chunk(2, dim=-1)

        if causal_conv1d_update is None:
            # Manual convolution step: shift the buffer left by one and append the new input
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Weight only; the bias is added below (or passed as dt_bias to the kernel)
        dt = F.linear(dt, self.dt_proj.weight)
        A = -torch.exp(self.A_log.float())

        if selective_state_update is None:

            # Pure-PyTorch state update: state <- state * exp(dt * A) + x * (dt * B)
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Create zero-filled ``(conv_state, ssm_state)`` buffers for decoding.

        ``max_seqlen`` and ``**kwargs`` are accepted but not used by this method.
        When ``dtype`` is None the buffers take the dtype of the conv / dt_proj weights.
        """
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype

        ssm_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        """Fetch this layer's cached states, creating zeroed ones on first use.

        The cache is ``inference_params.key_value_memory_dict`` keyed by
        ``self.layer_idx``. ``initialize_states=True`` zeroes states that already
        exist in the cache; freshly created states are zero anyway.
        """
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            # NOTE: batch_shape is assigned but never used
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,

            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]

            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state


### 定义mamba模块参数

In [None]:

# Smoke test: push a random (batch, length, dim) tensor through a single Mamba layer.
# NOTE(review): the device is hard-coded to "cuda", so this cell fails on a CPU-only runtime.
batch, length, dim = 2, 3, 100
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    d_model=dim,
    d_state=16,
    d_conv=4,
    expand=2,
).to("cuda")
y = model(x)
# Notebook display of the output shape
y.shape

### mamba注意力机制

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MambaAttention(nn.Module):
    """Encoder-style block whose sequence-mixing sublayer is a Mamba layer.

    Structure: Mamba -> dropout -> residual add -> LayerNorm, followed by a
    two-layer ReLU feed-forward network -> dropout -> residual add -> LayerNorm.
    The output has the same shape as the input.
    """

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.1):
        """
        Args:
            d_model: feature width of the input and output.
            dim_feedforward: hidden width of the feed-forward network.
            dropout: dropout probability shared by all three dropout layers.
        """
        super().__init__()
        # Sequence mixer (kept under the attribute name `self_attn`)
        self.self_attn = Mamba(d_model=d_model, d_state=16, d_conv=4, expand=2)
        # Position-wise feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Post-sublayer normalisation and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src):
        """Apply the Mamba sublayer and the feed-forward sublayer to ``src``."""
        # Sublayer 1: Mamba mixing with residual connection, then LayerNorm
        mixed = self.norm1(src + self.dropout1(self.self_attn(src)))

        # Sublayer 2: feed-forward network with residual connection, then LayerNorm
        hidden = F.relu(self.linear1(mixed))
        ffn_out = self.linear2(self.dropout(hidden))
        return self.norm2(mixed + self.dropout2(ffn_out))

# Smoke test: the block should return a tensor with the same shape as its input
mamba_att = MambaAttention(100).to(device)

src = torch.rand(10, 3, 100).to(device)
output = mamba_att(src)

print(output.shape)

### 整体模型构建

In [None]:
class bigModel(nn.Module):
    """Three-modality classifier with pairwise cross-attention and a Mamba fusion block.

    The image, text and audio encoders are taken from the notebook globals
    ``img_block``, ``text_block`` and ``audio_block`` at construction time.
    Each encoder is expected to return a ``(batch, 100)`` vector: the six stacked
    vectors are flattened into the 600 inputs of ``lin1`` (assumption derived from
    the layer sizes -- TODO confirm against the encoder definitions).
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.img_block = img_block
        self.text_block = text_block
        self.audio_block = audio_block
        # NOTE(review): `relu` is registered but never applied in forward(), so lin1 -> lin2
        # is a purely linear head. Left unchanged on purpose; confirm whether a
        # non-linearity between the two layers was intended.
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(600, 50)
        self.lin2 = nn.Linear(50, n_class)
        self.att = MambaAttention(d_model=100)
        self.cnn = nn.Conv1d(6, 6, 3,padding='same')
        # One cross-attention module per modality pair
        self.ca1 = nn.MultiheadAttention(100,2)
        self.ca2 = nn.MultiheadAttention(100,2)
        self.ca3 = nn.MultiheadAttention(100,2)

    @staticmethod
    def _cross_attend(attn, query, context):
        """Attend from ``query`` to ``context`` independently for every sample.

        ``nn.MultiheadAttention`` interprets a 2-D input as ONE unbatched sequence,
        so passing ``(batch, embed)`` tensors directly makes every sample attend to
        the other samples of the mini-batch. Inserting a length-1 sequence axis
        (``(1, batch, embed)``, matching the modules' default ``batch_first=False``)
        keeps the batch dimension a batch dimension.
        """
        if query.dim() != 2:
            # Inputs that already carry a sequence axis are passed through unchanged
            out, _ = attn(query, context, context)
            return out
        ctx = context.unsqueeze(0)
        out, _ = attn(query.unsqueeze(0), ctx, ctx)
        return out.squeeze(0)

    def forward(self, img,text,audio):
        """Classify one batch.

        Returns:
            ``(logits, x1, x2, x3)``: the class logits followed by the image, text
            and audio encoder outputs.
        """
        x1 = self.img_block(img)
        x2 = self.text_block(text)
        x3 = self.audio_block(audio)
        # Pairwise cross-attention: image->text, text->audio, audio->image
        x4 = self._cross_attend(self.ca1, x1, x2)
        x5 = self._cross_attend(self.ca2, x2, x3)
        x6 = self._cross_attend(self.ca3, x3, x1)
        # Treat the six vectors as a length-6 sequence for the Mamba block and the conv
        x = torch.stack([x1,x2,x3,x4,x5,x6],dim=1)
        x = self.att(x)
        x = self.cnn(x)
        x = self.flatten(x)
        x = self.lin1(x)
        x = self.lin2(x)
        return x,x1,x2,x3
# Instantiate the model on the selected device and check the logits shape on one batch.
# NOTE(review): b_img / b_text / b_audio are not defined in the visible cells -- presumably
# one batch drawn from a DataLoader; confirm.
big_model = bigModel().to(device)
big_model(b_img.to(device),b_text.to(device),b_audio.to(device))[0].shape

### 损失函数和优化器

In [None]:
# Classification loss over the model's logits
loss_fn = nn.CrossEntropyLoss()
# Adam over all parameters of big_model, learning rate 1e-3
optimizer = torch.optim.Adam(big_model.parameters(), lr=1e-3)


### 对比学习函数

In [None]:
class CtrastiveLoss(nn.Module):
    """Symmetric cross-entropy contrastive loss between two batches of feature vectors.

    Row ``k`` of ``features_i`` and row ``k`` of ``features_j`` form the positive
    pair; all other rows of the batch act as negatives.
    """

    def __init__(self, temperature=0.07):
        """
        Args:
            temperature: divisor applied to the similarity logits.
        """
        super().__init__()
        self.temperature = temperature
        # Kept as an attribute; forward() computes similarities with a matrix product instead
        self.cosine_similarity = nn.CosineSimilarity(dim=-1)

    def forward(self, features_i, features_j):
        """Return the mean of the i->j and j->i cross-entropy terms (a scalar tensor)."""
        # L2-normalise the rows so the matrix product below yields cosine similarities
        unit_i = F.normalize(features_i, dim=1)
        unit_j = F.normalize(features_j, dim=1)

        # [N, N] similarity logits, sharpened by the temperature
        logits = (unit_i @ unit_j.T) / self.temperature

        # The matching pair sits on the diagonal
        targets = torch.arange(features_i.size(0)).to(features_i.device)

        forward_term = F.cross_entropy(logits, targets)
        backward_term = F.cross_entropy(logits.T, targets)
        return (forward_term + backward_term) / 2

# Smoke test of the contrastive loss on two random 10 x 100 feature batches
ctrastive_loss = CtrastiveLoss()
fea1 = torch.randn(10, 100)
fea2 = torch.randn(10, 100)
ctrastive_loss(fea1, fea2)

### 模型训练和测试

In [None]:
# Train for a fixed number of epochs, evaluate on both loaders after every epoch and keep
# the weights of the epoch with the best test score.
# (train_cl / test_cl are not defined in the visible cells; test_cl is assumed to return
# (loss, accuracy) -- TODO confirm.)
best_acc = 0.0
best_model_path = 'best_model_weights.pth'
epochs = 20
train_loss_list = []
train_acc_list = []
test_loss_list = []
test_acc_list = []
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_cl(train_dl, big_model, loss_fn, optimizer)

    # Re-evaluate on the training data as well, so both curves use the same metric code
    train_loss, train_correct = test_cl(train_dl, big_model, loss_fn)
    test_loss, test_correct = test_cl(test_dl, big_model, loss_fn)

    train_loss_list.append(train_loss)
    train_acc_list.append(train_correct)
    test_loss_list.append(test_loss)
    test_acc_list.append(test_correct)

    # Checkpoint whenever the test score improves
    if test_correct > best_acc:
        best_acc = test_correct
        torch.save(big_model.state_dict(), best_model_path)
        print(f"New best accuracy of {best_acc}% achieved at epoch {t+1}! Model weights saved.")

# A checkpoint is written only when the score rises above the initial 0.0, so best_acc > 0.0
# means this run saved one. Without the guard, the load would either crash (no file) or
# silently restore a stale checkpoint left over from an earlier run.
if best_acc > 0.0:
    big_model.load_state_dict(torch.load(best_model_path))
else:
    print(f"No checkpoint was saved during this run; keeping the final weights instead of loading {best_model_path}.")
