In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset,Subset
from torch.utils.data import random_split
import torch.nn.init as init
import torch.nn.functional as F

import json
import h5py
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
# from adopt import ADOPT
from sklearn.model_selection import train_test_split
import math

from einops import rearrange
from einops.layers.torch import Rearrange
import random


In [2]:
def set_seed(seed):
    """Seed every random number generator used here and enable deterministic PyTorch.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's legacy global generator
    and the stdlib ``random`` module with the same value, then turns on the
    cuDNN / PyTorch determinism switches.

    Args:
        seed: Value handed unchanged to each seeding function.

    Returns:
        None. The function only mutates global state (RNGs, environment
        variables, PyTorch backend flags).
    """
    # Must be in the environment for deterministic cuBLAS kernels once
    # torch.use_deterministic_algorithms(True) is active.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    # Same seed for every generator: torch CPU, torch CUDA (all devices),
    # NumPy, Python's `random` -- in that order.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True  # only applies to CUDA convolutional operation.
    cudnn_backend.benchmark = False

    # NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting
    # it here can only influence child processes, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)

    torch.use_deterministic_algorithms(True)

In [3]:
class TemporalConv1D(nn.Module):
    """Simple temporal conv block for feature alignment.

    Pipeline: ``Conv1d -> ReLU -> BatchNorm1d -> Dropout``.

    Args:
        in_channels: Input channel count passed to ``nn.Conv1d``.
        out_channels: Output channel count of the convolution; also the
            ``num_features`` of the batch-norm layer.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Padding passed to ``nn.Conv1d``.
        dp: Dropout probability. With the default ``0.0`` the dropout layer
            is an identity, so the block reduces to conv/ReLU/batch-norm.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,dp=0.0):
        super(TemporalConv1D, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                               padding=padding)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dp)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """Run the block on ``x``.

        Args:
            x: Tensor accepted by ``nn.Conv1d`` with ``in_channels`` channels,
                i.e. ``(batch, in_channels, length)``.

        Returns:
            Tensor with ``out_channels`` channels after conv, ReLU, batch norm
            and dropout.
        """
        x = self.conv(x)
        x = self.relu(x)
        x = self.bn(x)
        # The dropout layer was constructed but never called, which made the
        # `dp` argument a silent no-op. It is applied last so batch-norm
        # statistics are computed on un-dropped activations; in eval mode
        # (or with dp == 0.0) this line is an identity.
        x = self.dropout(x)
        return x

In [4]:
######## cross attention #######

class CrossAttention(nn.Module):
    """Cross-attention between two feature streams.

    Thin wrapper around ``nn.MultiheadAttention`` (built with
    ``batch_first=True``) that returns only the attended output and discards
    the attention weights.

    Args:
        embed_dim: Embedding dimension passed to ``nn.MultiheadAttention``.
        num_heads: Number of attention heads.
        dp: Dropout probability passed as the attention ``dropout``.
    """
    def __init__(self, embed_dim, num_heads=16, dp=0.0):
        super().__init__()
        attn_config = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dp,
            batch_first=True,
        )
        self.multihead_attn = nn.MultiheadAttention(**attn_config)

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value`` and return the attention output."""
        # MultiheadAttention returns (output, weights); only the output is kept.
        return self.multihead_attn(query, key, value)[0]

In [5]:
######## MHSA ########
class MHSA(nn.Module):
    """Multi-Head Self Attention.

    Wraps ``nn.MultiheadAttention`` (``batch_first=True``) and feeds the same
    tensor as query, key and value; the attention weights are discarded.

    Args:
        embed_dim: Embedding dimension passed to ``nn.MultiheadAttention``.
        num_heads: Number of attention heads.
        dp: Dropout probability passed as the attention ``dropout``.
    """
    def __init__(self, embed_dim, num_heads=16, dp=0.0):
        super().__init__()
        self.mhsa = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dp,
            batch_first=True,
        )

    def forward(self, x):
        """Self-attend over ``x`` and return the attention output only."""
        attended = self.mhsa(x, x, x)[0]
        return attended

In [None]:
######## Fusion model v4 ########

class FusionModel(nn.Module):
    """Fusion model skeleton with a shared Kaiming weight initializer.

    NOTE(review): as written, ``__init__`` registers no sub-layers and
    ``forward`` has no computation -- the architecture still has to be
    filled in. Only the weight-initialization helper is functional.
    """

    def __init__(self, swin_feat_dim=(50, 1024)):
        """Build the model and initialize its weights.

        Args:
            swin_feat_dim: Not read by the current code. Presumably the
                (tokens, channels) shape of the Swin features -- TODO confirm.
        """
        super(FusionModel, self).__init__()

        # TODO: define the sub-layers here. They must be created *before*
        # the call below, because ``apply`` only visits modules that are
        # already registered on ``self``.

        # Apply custom weight initialization
        self.apply(self.kaiming_initialize)

    def kaiming_initialize(self,layer):
        """Apply Kaiming-normal initialization to ``layer`` if it is a supported type.

        Handles ``nn.Conv1d`` (``fan_out``), ``nn.Linear`` (``fan_in``) and
        ``nn.MultiheadAttention`` (``fan_in`` for the input and output
        projections). Biases, when present, are set to zero. Any other module
        type is left untouched.

        Args:
            layer: Module to initialize; intended to be used as the callback
                of ``nn.Module.apply``.
        """
        if isinstance(layer, nn.Conv1d):
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
        elif isinstance(layer, nn.Linear):
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
        elif isinstance(layer, nn.MultiheadAttention):
            # Initialize the input projection(s). When kdim/vdim differ from
            # embed_dim, MultiheadAttention keeps separate q/k/v matrices and
            # ``in_proj_weight`` is None, so initialize those instead of
            # crashing on the missing packed weight.
            if layer.in_proj_weight is not None:
                nn.init.kaiming_normal_(layer.in_proj_weight, mode='fan_in', nonlinearity='relu')
            else:
                for proj_weight in (layer.q_proj_weight, layer.k_proj_weight, layer.v_proj_weight):
                    nn.init.kaiming_normal_(proj_weight, mode='fan_in', nonlinearity='relu')
            if layer.in_proj_bias is not None:
                nn.init.constant_(layer.in_proj_bias, 0)
            # Initialize out_proj weight and bias
            nn.init.kaiming_normal_(layer.out_proj.weight, mode='fan_in', nonlinearity='relu')
            if layer.out_proj.bias is not None:
                nn.init.constant_(layer.out_proj.bias, 0)

    def forward(self, swin_features ):
        """Forward pass placeholder.

        Args:
            swin_features: Input features; unused until the forward
                computation is implemented.

        Raises:
            NotImplementedError: Always. The body used to return an undefined
                name ``output``, which either raised ``NameError`` or silently
                picked up an unrelated notebook global of that name.
        """
        raise NotImplementedError(
            "FusionModel.forward is not implemented: define the layers in "
            "__init__ and compute the output from swin_features here."
        )

In [None]:
def load_data(swin_file_path,labels_file_path):
    """Read the feature and label arrays from disk and return them as tensors.

    Args:
        swin_file_path: Path passed to ``np.load`` for the feature array.
        labels_file_path: Path passed to ``np.load`` for the label array.

    Returns:
        tuple: ``(swin_features, labels)`` where the features are a
        ``torch.float32`` tensor and the labels a ``torch.int64`` tensor.
    """
    # Features are read first, then labels.
    feature_array, label_array = (
        np.load(path) for path in (swin_file_path, labels_file_path)
    )

    feature_tensor = torch.from_numpy(feature_array).to(torch.float32)
    label_tensor = torch.from_numpy(label_array).to(torch.int64)

    return feature_tensor, label_tensor