In [None]:
#Creating labels for audio and visuals at a single time. (to be ran first for lables)

import numpy as np
import os
import re

class Create_Labels():
    def __init__(self, config):
        self.config = config
        self.pattern = r'^frame_(\d{6})_face_(\d{2})\.jpg$'

    def load_segment(self, file_path):
        return np.load(file_path)

    def create_continuous_labels(self, total_samples, category):
        label_value = 1 if category == "PTSD" else 0
        return np.full(total_samples, label_value)

    def save_continuous_labels(self, labels, video_folder):
        output_filename = f"{os.path.basename(video_folder)}_Continuous_Label.npy"
        output_path = os.path.join(video_folder, output_filename)
        np.save(output_path, labels)

    def get_category_from_path(self, file_path):
        parts = file_path.split(os.sep)
        return parts[-2]

    def get_sorted_npy_segment_files(self, video_folder):
        segment_files = [f for f in os.listdir(video_folder) if f.startswith("segment_") and f.endswith(".npy")]
        return sorted(segment_files, key=lambda x: int(re.findall(r'segment_(\d+)\.npy', x)[0]))


    def get_sorted_visual_segment_folders(self, video_folder):
        return sorted(
            [f for f in os.listdir(video_folder) if f.startswith("segment_") and os.path.isdir(os.path.join(video_folder, f))],
            key=lambda x: int(re.findall(r'segment_(\d+)', x)[0])
        )

    def process_faces_files(self, directory):
        face_files = []
        Faces_Folder =  os.path.join(directory, 'faces')
        for filename in os.listdir(Faces_Folder):
            match = re.match(self.pattern, filename)
            if match:
                frame_number = int(match.group(1))
                face_number = int(match.group(2))
                face_files.append({
                    'filename': filename,
                    'frame_number': frame_number,
                    'face_number': face_number
                })
        face_files.sort(key=lambda x: (x['frame_number'], x['face_number']))
        return face_files


    def audio_labels(self, video_folder):
        category = self.get_category_from_path(video_folder)
        segment_files = self.get_sorted_npy_segment_files(video_folder)
        
        total_samples = sum(self.load_segment(os.path.join(video_folder, f)).shape[0] for f in segment_files)
        print(f' damples for {total_samples}')
        continuous_labels = self.create_continuous_labels(total_samples, category)
        self.save_continuous_labels(continuous_labels, video_folder)

        current_index = 0
        for segment_file in segment_files:
            segment_path = os.path.join(video_folder, segment_file)
            segment_data = self.load_segment(segment_path)
            segment_labels = continuous_labels[current_index:current_index + segment_data.shape[0]]

            print(f"Segment: {segment_file}, Shape: {segment_data.shape}, Labels: {segment_labels}")
            current_index += segment_data.shape[0]
    
    def visual_labels(self, video_folder):
        category = self.get_category_from_path(video_folder)
        segment_folders = self.get_sorted_visual_segment_folders(video_folder)

        total_samples = 0
        all_face_files = []

        for segment_folder in segment_folders:
            segment_path = os.path.join(video_folder, segment_folder)
            all_face_files.extend(self.process_faces_files(segment_path))

        total_samples = len(all_face_files)

        continuous_labels = self.create_continuous_labels(total_samples, category)
        continuous_labels_path = self.save_continuous_labels(continuous_labels, video_folder)
        print(f"Total faces in the video are : {total_samples}")

        current_index = 0
        for segment_folder in segment_folders:
            segment_path = os.path.join(video_folder, segment_folder)
            segment_face_files = self.process_faces_files(segment_path)
            
            print(f"{segment_folder} has {len(segment_face_files)} faces")
            print("Frame\tFace\tLabel")
            print("-" * 20)

            for face_file in segment_face_files:
                face_label = continuous_labels[current_index]
                print(f"{face_file['frame_number']}\t{face_file['face_number']}\t{face_label}")
                current_index += 1


    def processs_2_modalities(self):
        for category in ["PTSD", "Non-PTSD"]:
            category_path = os.path.join(self.config['data_directory'], category)
            for video_folder in os.listdir(category_path):
                video_path = os.path.join(category_path, video_folder)
                if os.path.isdir(video_path):
                    print(f"\nProcessing video: {video_folder}")
                    if self.config['data_directory'].endswith('Audio'):
                        self.audio_labels(video_path)
                    elif self.config['data_directory'].endswith('Visual'):
                        self.visual_labels(video_path)

def main():
    base_directory = '/media/scratch/datasets/PTSD_Project_train_validation_test_split/Final_train_new'
    for modality in [ 'Audio', 'Visual']:
        config = {'data_directory': os.path.join(base_directory, modality)}
        print(f"\nProcessing {modality} data:")
        preprocessor = Create_Labels(config)
        preprocessor.processs_2_modalities()
main()

In [None]:
#Label loading and creating batches (in a sequential-manner) for Audio and Visual from Final_Train (2nd run)
import os
import re
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from PIL import Image
from torchvision import transforms

device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")

# Audio Modality Classes and Functions
class Load_Audio_Segments:
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.video_folders = sorted([f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))])
        self.segment_data = []  
        self.video_boundaries = []  
        
        for video_folder in self.video_folders:
            folder_path = os.path.join(root_dir, video_folder)
            all_labels = torch.from_numpy(np.load(os.path.join(folder_path, f"{video_folder}_Continuous_Label.npy"))).float().to(device)
            
            segment_files = sorted(
                [f for f in os.listdir(folder_path) if f.startswith('segment_') and f.endswith('.npy')],
                key=lambda x: int(re.search(r'segment_(\d+)\.npy', x).group(1)))
            
            video_start = len(self.segment_data)
            current_pos = 0  
            
            for seg_file in segment_files:
                seg_path = os.path.join(folder_path, seg_file)
                audio_tensor = torch.from_numpy(np.load(seg_path)).float().to(device)
                num_frames = audio_tensor.shape[0]
                
                label_tensor = all_labels[current_pos:current_pos + num_frames]
                current_pos += num_frames
                
                self.segment_data.append((audio_tensor, label_tensor, seg_file, seg_path))
            self.video_boundaries.append((video_start, len(self.segment_data)))

class VideoSegmentDataset(Dataset):
    def __init__(self, segments):
        self.segments = segments

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        audio_tensor, label_tensor, _, _ = self.segments[idx]
        return audio_tensor.to(device), label_tensor.to(device)

def no_batch_collate_fn(batch):
    return batch[0]

def process_audio_modality(config):
    processed_data = {'PTSD': {'Audio': {}}, 'Non-PTSD': {'Audio': {}}}
    for category in ["PTSD", "Non-PTSD"]:
        category_path = os.path.join(config['data_directory'], category)
        if not os.path.exists(category_path):
            continue
        dataset = Load_Audio_Segments(category_path)
        for vid_idx, (video_start, video_end) in enumerate(dataset.video_boundaries):
            video_name = dataset.video_folders[vid_idx]
            video_segments = dataset.segment_data[video_start:video_end]
            video_dataset = VideoSegmentDataset(video_segments)
            dataloader = DataLoader(
                video_dataset,
                batch_size=1,
                shuffle=False,
                collate_fn=no_batch_collate_fn)
            segments = []
            for batch in dataloader:
                segments.append(batch)
            processed_data[category]['Audio'][video_name] = segments
    return processed_data

# Visual Modality Classes and Functions
class Load_Visual_Labels_batches(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.video_folders = sorted([f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))])
        self.segments = []
        self.transform = transform
        self.segment_info = []
        self.label_files = {}
        self.label_counts = {}
        self.pattern = r'^frame_(\d{6})_face_(\d{2})\.jpg$'

        for video_folder in self.video_folders:
            folder_path = os.path.join(root_dir, video_folder)
            segment_folders = sorted([f for f in os.listdir(folder_path) if f.startswith('segment_')],
                                     key=lambda x: int(re.search(r'segment_(\d+)', x).group(1)))
            labels_path = os.path.join(folder_path, f"{video_folder}_Continuous_Label.npy")
            self.label_files[video_folder] = np.load(labels_path)
            self.label_counts[video_folder] = len(self.label_files[video_folder])

            current_start = 0
            for segment_folder in segment_folders:
                segment_path = os.path.join(folder_path, segment_folder, 'faces')
                face_files = self.process_files(segment_path)
                num_faces = len(face_files)
                if num_faces == 0:
                    print(f" {segment_folder} skipping as it has no faces")
                    continue
                self.segments.append((segment_path, face_files))
                self.segment_info.append((video_folder, segment_folder, len(self.segments) - 1, num_faces, current_start))
                current_start += num_faces

    def __len__(self):
        return len(self.segment_info)

    def process_files(self, directory):
        face_files = []
        for filename in os.listdir(directory):
            match = re.match(self.pattern, filename)
            if match:
                face_files.append(filename)
        return sorted(face_files)
    
    def __getitem__(self, idx):
        video_folder, segment_folder, segment_idx, num_faces, label_start = self.segment_info[idx]
        segment_path, face_files = self.segments[segment_idx]
        
        images = []
        labels = []
        for i, face_file in enumerate(face_files):
            face_image_path = os.path.join(segment_path, face_file)
            image = Image.open(face_image_path).convert('RGB')

            if self.transform:
                image = self.transform(image)
            images.append(image)
            label_index = label_start + i
            if label_index < self.label_counts[video_folder]:
                labels.append(self.label_files[video_folder][label_index])
            else:
                labels.append(self.label_files[video_folder][-1])           
        return torch.stack(images), torch.tensor(labels).float()


def process_visual_modality(config):
    processed_data = {'PTSD': {'Visual': {}}, 'Non-PTSD': {'Visual': {}}}
    for category in ["PTSD", "Non-PTSD"]:
        category_path = os.path.join(config['data_directory'], category)
        if not os.path.exists(category_path):
            continue
        dataset = Load_Visual_Labels_batches(category_path, transform=config['transform'])
        for video_folder in dataset.video_folders:
            video_segments = [seg for seg in dataset.segment_info if seg[0] == video_folder]
            segments = []
            for segment_info in video_segments:
                _, segment_folder, segment_idx, _, _ = segment_info
                _, face_files = dataset.segments[segment_idx]
                images, labels = dataset[segment_idx]
                #Display_images_with_labels(images, labels, face_files, segment_folder)
                segments.append((images, labels))
            
            processed_data[category]['Visual'][video_folder] = segments
    return processed_data

# Main Processing
base_directory = '/media/scratch/datasets/PTSD_Project_train_validation_test_split'
dataset_types = ['Final_train_new']
final_datasets = {dataset_type: {'PTSD': {'Audio': {}, 'Visual': {}}, 'Non-PTSD': {'Audio': {}, 'Visual': {}}} for dataset_type in dataset_types}

for dataset_type in dataset_types:
    dataset_dir = os.path.join(base_directory, dataset_type)
    
    # Process Audio
    audio_config = {'data_directory': os.path.join(dataset_dir, 'Audio')}
    audio_data = process_audio_modality(audio_config)
    
    # Process Visual
    visual_config = {
        'data_directory': os.path.join(dataset_dir, 'Visual'),
        'transform': transforms.Compose([transforms.Resize((48, 48)), transforms.ToTensor()])
    }
    visual_data = process_visual_modality(visual_config)
    
    # Merge data
    for category in ['PTSD', 'Non-PTSD']:
        if category in audio_data and 'Audio' in audio_data[category]:
            for video, segments in audio_data[category]['Audio'].items():
                final_datasets[dataset_type][category]['Audio'][video] = segments
        if category in visual_data and 'Visual' in visual_data[category]:
            for video, segments in visual_data[category]['Visual'].items():
                final_datasets[dataset_type][category]['Visual'][video] = segments

for dataset_type in dataset_types:
    for category in ['PTSD', 'Non-PTSD']:
        for modality in ['Audio', 'Visual']:
            for video in final_datasets[dataset_type][category][modality]:
                print(modality, video)
                segments = final_datasets[dataset_type][category][modality][video]
                for i in range(len(segments)):
                    print(i)
                    samples, labels = segments[i]
                    print('samples is ', (samples.shape))
                    labels_np = labels.cpu().numpy() if isinstance(labels, torch.Tensor) else labels
                    length = labels_np.shape[0]
                    print('length is ', length)
                    pad_value = 1.0 if category == "PTSD" else 0.0
                    if length < 500:
                        padded_labels = np.pad(labels_np, (0, 500 - length), mode='constant', constant_values=pad_value)
                    else:
                        padded_labels = labels_np[:500]
                    padded_labels = torch.tensor(padded_labels, dtype=torch.float32).to(device)
                    segments[i] = (samples.to(device), padded_labels)
            print("--------")
X_processed = {'video': {'PTSD': final_datasets['Final_train_new']['PTSD']['Visual'],'Non-PTSD': final_datasets['Final_train_new']['Non-PTSD']['Visual']},
    'logmel': {'PTSD': final_datasets['Final_train_new']['PTSD']['Audio'], 'Non-PTSD': final_datasets['Final_train_new']['Non-PTSD']['Audio']}}
for modality in ['video', 'logmel']:
    for category in ['PTSD', 'Non-PTSD']:
        print(f"\n {modality} {category}")
        video_dict = X_processed[modality][category]
        
        for video_name, segments in video_dict.items():
            print(f"\n  Video: {video_name}")
            print(f"  Number of segments: {len(segments)}")
            '''for seg_idx, (samples, labels) in enumerate(segments):
                print(f"\n    Segment {seg_idx + 1}:")
                print(f"      Samples shape: {samples.shape}")  
                print(f"      Labels shape: {labels.shape}")'''

In [None]:
import torch
import torch.nn as nn
import math
from torch.nn import Linear, BatchNorm1d, BatchNorm2d, Dropout, Sequential, Module

device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
import torch.nn.functional as F

regressor1 = nn.Linear(160, 256).to(device)
bn1 = BatchNorm1d(256).to(device)
regressor2 = nn.Linear(256, 1).to(device)

class DenseCoAttn(nn.Module):
	def __init__(self, dim1, dim2, dropout):
		super(DenseCoAttn, self).__init__()
		dim = dim1 + dim2
		self.dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in range(2)])
		self.query_linear = nn.Linear(dim, dim)

		self.key1_linear = nn.Linear(500, 500)
		self.key2_linear = nn.Linear(500, 500)
		self.value1_linear = nn.Linear(dim1, dim1)
		self.value2_linear = nn.Linear(dim2, dim2)
		self.relu = nn.ReLU()

	def forward(self, value1, value2):

		joint = torch.cat((value1, value2), dim=-1)
		# audio  audio*W*joint
		va_joint = self.query_linear(joint)
		key1 = self.key1_linear(value1.transpose(1, 2))
		key2 = self.key2_linear(value2.transpose(1, 2))

		value1 = self.value1_linear(value1)
		value2 = self.value2_linear(value2)
          
		weighted1, attn1 = self.qkv_attention(joint, key1, value1, dropout=self.dropouts[0])
		weighted2, attn2 = self.qkv_attention(joint, key2, value2, dropout=self.dropouts[1])
          
		return weighted1, weighted2
     
	def qkv_attention(self, query, key, value, dropout=None):
		d_k = query.size(-1)
		scores = torch.bmm(key, query) / math.sqrt(d_k)
		scores = torch.tanh(scores)
        
		if dropout:
			scores = dropout(scores)
		weighted = torch.tanh(torch.bmm(value, scores))
		return self.relu(weighted), scores

class NormalSubLayer(nn.Module):

    def __init__(self, dim1, dim2, dropout):
        super(NormalSubLayer, self).__init__()
        self.dense_coattn = DenseCoAttn(dim1, dim2, dropout)
        self.linears = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim1 + dim2, dim1),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dropout),
            ),
            nn.Sequential(
                nn.Linear(dim1 + dim2, dim2),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dropout))])

    def forward(self, data1, data2):
        weighted1, weighted2 = self.dense_coattn(data1, data2)
        data1 = data1 + self.linears[0](weighted1)
        data2 = data2 + self.linears[1](weighted2)
        return data1, data2

class DCNLayer(nn.Module):
    def __init__(self, dim1, dim2, num_seq, dropout):
        super(DCNLayer, self).__init__()
        self.dcn_layers = nn.ModuleList([NormalSubLayer(dim1, dim2, dropout) for _ in range(num_seq)])

    def forward(self, data1, data2):
        for dense_coattn in self.dcn_layers:
            data1, data2 = dense_coattn(data1, data2)
        return data1, data2

In [None]:
import math
import torch
import torch.nn as nn
import sys
import torch
from base.vggish.vggish import VGGish
from models.temporal_convolutional_model import TemporalConvNet
import math
import os
import torch
from torch import nn
import numpy as np
from torch.nn import Linear, BatchNorm1d, BatchNorm2d, Dropout, Sequential, Module
import torch.nn.functional as F
from models.arcface_model import Backbone
import tensorflow as tf
import torch

#working with the audio
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class VGG(nn.Module):
    def __init__(self, features):
        super(VGG, self).__init__()
        self.features = features
        self.embeddings = nn.Sequential(
            nn.Linear(512 * 4 * 6, 4096),
            nn.ReLU(True),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Linear(4096, 128))

    def forward(self, x):
        x = self.features(x)
        x = torch.transpose(x, 1, 3)
        x = torch.transpose(x, 1, 2)
        x = x.contiguous()
        x = x.view(x.size(0), -1)
        x = self.embeddings(x)
        return x

def make_layers():
    layers = []
    in_channels = 1  
    for v in [64, "M", 128, "M", 256, 256, "M", 512, 512, "M"]:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

def _vgg():
    return VGG(make_layers())

class VGGish(VGG):
    def __init__(self):
        super().__init__(make_layers())

    def forward(self, x, fs=None):
        x = VGG.forward(self, x)
        return x
    
class AudioBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = VGGish()
        for param in self.backbone.parameters():
            param.requires_grad = False 
            
    def forward(self, x):
        x = self.backbone(x)
        return x

#working with the visual
class VisualBackbone(nn.Module):
    def __init__(self, input_channels=3, num_classes=8, use_pretrained=True, state_dict_path="/media/scratch/Trauma_Detection/code/RestNet50.pth", mode="ir",
                 embedding_dim=512):
        super().__init__()
        self.backbone = Backbone(input_channels=input_channels, num_layers=50, drop_ratio=0.4, mode=mode)
        if use_pretrained:
            state_dict = torch.load(state_dict_path, map_location=device)

            if "backbone" in list(state_dict.keys())[0]:

                self.backbone.output_layer = Sequential(BatchNorm2d(embedding_dim),
                                                        Dropout(0.4),
                                                        Flatten(),
                                                        Linear(embedding_dim * 5 * 5, embedding_dim),
                                                        BatchNorm1d(embedding_dim))
                new_state_dict = {}
                for key, value in state_dict.items():

                    if "logits" not in key:
                        new_key = key[9:]
                        new_state_dict[new_key] = value

                self.backbone.load_state_dict(new_state_dict)
            else:
                self.backbone.load_state_dict(state_dict)

            for param in self.backbone.parameters():
                param.requires_grad = False

        self.backbone.output_layer = Sequential(BatchNorm2d(embedding_dim),
                                                Dropout(0.4),
                                                Flatten(),
                                                Linear(embedding_dim * 5 * 5, embedding_dim),
                                                BatchNorm1d(embedding_dim))
        
        self.backbone.output_layer = Sequential(
        BatchNorm2d(embedding_dim),
        Dropout(0.4),
        nn.AdaptiveAvgPool2d((5, 5)),
        Flatten(),
        Linear(embedding_dim * 5 * 5, embedding_dim),
        BatchNorm1d(embedding_dim))

        self.logits = nn.Linear(in_features=embedding_dim, out_features=num_classes)

        from torch.nn.init import xavier_uniform_, constant_

        for m in self.backbone.output_layer.modules():
            if isinstance(m, nn.Linear):
                m.weight = xavier_uniform_(m.weight)
                m.bias = constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


        self.logits.weight = xavier_uniform_(self.logits.weight)
        self.logits.bias = constant_(self.logits.bias, 0)

    def forward(self, x):
        x = self.backbone(x)
        return x

    def extract(self, x):
        x = self.backbone(x)
        return x

class Joint_Module(nn.Module):
    def __init__(self, modality=['video', 'logmel'], kernel_size=5, example_length=500, tcn_attention=0,
                 tcn_channel={'video': [512, 256, 256, 128], 'cnn_res50': [512, 256, 256, 128], 'mfcc':[32, 32, 32, 32], 'vggish': [32, 32, 32, 32], 'logmel': [32, 32, 32, 32]},
                 embedding_dim={'video': 512,  'bert': 768, 'cnn_res50': 512, 'mfcc': 39, 'vggish': 128, 'logmel': 128, 'egemaps': 88},
                 encoder_dim={'video': 128, 'bert': 128, 'cnn_res50': 128, 'mfcc': 32, 'vggish': 32, 'logmel': 32, 'egemaps': 32},
                 modal_dim=32, num_heads=2,
                 root_dir='', device=device):
        super().__init__()
        self.examples_segment = segments
        self.root_dir = root_dir
        self.device = device
        self.modality = modality
        self.kernel_size = kernel_size
        self.example_length = example_length
        self.tcn_channel = tcn_channel
        self.tcn_attention = tcn_attention
        self.embedding_dim = embedding_dim
        self.encoder_dim = encoder_dim
        self.outputs = {}
        self.temporal, self.fusion = nn.ModuleDict(), None
        self.num_heads = num_heads
        self.modal_dim = modal_dim
        self.final_dim = self.embedding_dim[self.modality[0]] + self.embedding_dim[self.modality[1]] + self.embedding_dim[self.modality[0]]
        self.spatial = nn.ModuleDict()
        self.bn = nn.ModuleDict()
        self.Length = 500
        self.spatial_visual_model = VisualBackbone().to(device)
        self.spatial_audio_model = AudioBackbone().to(device)

    
    def init(self):
        self.output_dim = 1
        
        for modality in self.modality:
            self.temporal[modality] = TemporalConvNet(num_inputs=self.embedding_dim[modality], max_length=self.example_length,
                                                   num_channels=self.tcn_channel[modality], attention=self.tcn_attention,
                                                   kernel_size=self.kernel_size, dropout=0.1).to(self.device)
            self.bn[modality] = BatchNorm1d(self.tcn_channel[modality][-1])

        self.coattn = DCNLayer(128, 32, 2, 0.6)
        #self.regressor1 = nn.Linear(512, 256) --Initial
        self.regressor1 = nn.Linear(256, 256)
        self.bn1 = BatchNorm1d(256)
        self.regressor2 = nn.Linear(256, self.output_dim)

    def forward(self,X):
        Video, Audio = self.coattn(X['visual_features'][0],X['logmel_features'][0])
        c = torch.cat((Video, Audio), dim=-1)
        c = regressor1(c).transpose(1, 2)
        c = bn1(c).transpose(1, 2)
        c = F.leaky_relu(c)
        c = regressor2(c)
        c = torch.tanh(c)
        return c

    def forward_video(self, segment):
        batch_input = torch.tensor(segment, dtype=torch.float32).to(self.device)
        batch_input = batch_input.unsqueeze(0)
        batch_size, length, channel, width, height = batch_input.shape
        batch_input = batch_input.view(-1, channel, width, height)
        batch_input = batch_input.to(device) 
        batch_input = self.spatial_visual_model(batch_input)
        batch_input = batch_input.to(self.device)
        _, feature_dim = batch_input.shape
        video_features = batch_input.view(batch_size, length, feature_dim).unsqueeze(1)
        return video_features.squeeze(1).transpose(1, 2)  

    def forward_logmel(self, segment):  
        segment = torch.tensor(segment, dtype=torch.float32).to(self.device)
        batch_input = segment.unsqueeze(1).float().to(self.device)
        batch_size = batch_input.shape[1]
        length = batch_input.shape[0]
        vggish_features = self.spatial_audio_model(batch_input)
        #vggish_features = vggish_features.cpu().detach().numpy()
        _, feature_dim = vggish_features.shape
        logmel_features = torch.tensor(vggish_features).view(batch_size, length, feature_dim).unsqueeze(1)
        return logmel_features.squeeze(1).transpose(1, 2)

    def compact_feature(self, spatial_tensor):
        sequence_length = spatial_tensor.shape[2]
        length_difference = self.Length - sequence_length
        if length_difference > 0:
            last_slice = spatial_tensor[:, :, -1:]
            padding = last_slice.repeat(1, 1, length_difference)
            tensor = torch.cat((spatial_tensor, padding), dim=2)
        else:
            tensor = spatial_tensor
        return tensor

2025-04-11 15:43:21.240793: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744400601.248390 1446635 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744400601.250698 1446635 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744400601.257747 1446635 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744400601.257752 1446635 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744400601.257753 1446635 computation_placer.cc:177] computation placer alr

In [25]:
from operator import itemgetter
import numpy as np

class GenericParamControl(object):
    @staticmethod
    def init_module_list():
        raise NotImplementedError

    @staticmethod
    def init_param_group():
        raise NotImplementedError

    def get_param_group(self):
        raise NotImplementedError

    def release_param(self, model):
        raise NotImplementedError

class ResnetParamControl(GenericParamControl):
    def __init__(self, trainer, gradual_release=1, release_count=8, backbone_mode="ir"):
        self.trainer = trainer
        self.gradual_release = gradual_release
        self.release_count = release_count
        self.backbone_mode = backbone_mode
        self.module_dict = self.init_module_list()
        self.module_stack = self.init_param_group()
        self.early_stop = False

    def init_module_list(self):
        return {
            "spatial_visual_model": [
                [(4, 10)], 
                [(163, 187)], 
                [(142, 163)]
            ],
            "spatial_audio_model": [
                [(16, 18)], 
                [(14, 16)], 
                [(12, 14)]
            ]
        }

    def init_param_group(self):
        module_stack = {"spatial_visual_model": [], "spatial_audio_model": []}
        for modal, ranges in self.module_dict.items():
            for groups in ranges:
                slice_range = []
                for group in groups:
                    slice_range += list(np.arange(*group))
                module_stack[modal].append(slice_range)
        return module_stack

    def get_param_group(self, modal):
        return self.module_stack[modal].pop(0)

    def get_current_lr(self):
        return self.trainer.optimizer.param_groups[0]['lr']

    def release_param(self, model, epoch=0, modalities=['spatial_visual_model', 'spatial_audio_model']):
        if self.gradual_release and self.release_count > 0:
            for modal in modalities:
                indices = self.get_param_group(modal) 
                print(f"Releasing parameters for: {modal}")
                print(f"Releasing layers in range: {list(indices)}")
                modal_submodel = getattr(model, modal)  
                modal_params = list(modal_submodel.parameters())
                for param in itemgetter(*indices)(modal_params):
                    param.requires_grad = True
            self.trainer.init_optimizer_and_scheduler(epoch=epoch)
            self.release_count -= 1
            self.trainer.early_stopping_counter = self.trainer.early_stopping
        else:
            print("No more parameters to release!")
            self.early_stop = True

    def load_trainer(self, trainer):
        self.trainer = trainer

In [29]:
model = Joint_Module()
model.init()
model.to(device)
model.eval()

Spatial_visual_features, Temporal_visual_features = {}, {}
Spatial_logmel_features, Temporal_logmel_features = {}, {}

for modality, data in X_processed.items():
    print(f"{modality}\n-----")
    process_fn = model.forward_logmel if modality == 'logmel' else model.forward_video  

    for category, videos in data.items():
        if modality == 'logmel':
            Spatial_logmel_features.setdefault(category, {})
            Temporal_logmel_features.setdefault(category, {})
        else:
            Spatial_visual_features.setdefault( category, {})
            Temporal_visual_features.setdefault( category, {})
        
        for video_name, segments in videos.items():
            print(f"{video_name} with {len(segments)} segments of {category}")

            spatial_dict = Spatial_logmel_features if modality == 'logmel' else Spatial_visual_features
            temporal_dict = Temporal_logmel_features if modality == 'logmel' else Temporal_visual_features
            
            spatial_dict[category].setdefault(video_name, [])
            temporal_dict[category].setdefault(video_name, [])
            
            for i, (tensor, labels) in enumerate(segments[:3]):
                if i >= 3: 
                    break
                tensor = tensor.to(device)
                output = process_fn(tensor)
                spatial_features = output if isinstance(output, torch.Tensor) else output[0]
                spatial_features = model.compact_feature(spatial_features).to(device)
                spatial_dict[category][video_name].append([spatial_features, labels])
                temporal_features= model.bn[modality](model.temporal[modality](spatial_features.to(device))).transpose(1, 2).to(device)
                temporal_dict[category][video_name].append([temporal_features,labels])
torch.cuda.empty_cache()

  WeightNorm.apply(module, name, dim)
  batch_input = torch.tensor(segment, dtype=torch.float32).to(self.device)


video
-----
vidm031 with 33 segments of PTSD
Great_Day with 35 segments of Non-PTSD
logmel
-----
vidm031 with 33 segments of PTSD
Great_Day with 35 segments of Non-PTSD


  segment = torch.tensor(segment, dtype=torch.float32).to(self.device)
  logmel_features = torch.tensor(vggish_features).view(batch_size, length, feature_dim).unsqueeze(1)


In [30]:
from tqdm import tqdm
Y = {'logmel': {category: {video_name: Temporal_logmel_features[category][video_name] for video_name in Temporal_logmel_features[category]} for category in Temporal_logmel_features},
    'video': {category: {video_name: Temporal_visual_features[category][video_name] for video_name in Temporal_visual_features[category]} for category in Temporal_visual_features}}

def stacking_video_wise(temporal_dict):
    video_wise_data = [] 
    for category in temporal_dict:
        for video_name in temporal_dict[category]:
            features_list = []
            for segment in temporal_dict[category][video_name]:
                features_list.append(segment[0])
            if features_list:
                video_wise_data.append({
                    'video_name': video_name,
                    'category': category,
                    'features': torch.stack(features_list,dim=0)})
    return video_wise_data

temporal_visual_features = stacking_video_wise(Y['video'])
temporal_logmel_features = stacking_video_wise(Y['logmel'])

for visual_item, logmel_item in zip(temporal_visual_features, temporal_logmel_features):
    visual_item['features'] = visual_item['features'].squeeze(1)
    logmel_item['features'] = logmel_item['features'].squeeze(1)

def align_features(visual_data, logmel_data):
    aligned_data = []
    for vis_item, log_item in zip(visual_data, logmel_data):
        assert vis_item['video_name'] == log_item['video_name']
        num_segments = vis_item['features'].shape[0]
        print(num_segments)
        aligned_data.append({
            'video_name': vis_item['video_name'],
            'category': vis_item['category'],
            'visual_features': vis_item['features'],
            'logmel_features': log_item['features'],
            'label':  torch.full((num_segments, 500, 1), 
                                1 if vis_item['category'] == 'PTSD' else 0,
                                dtype=torch.float32),
            'segments': num_segments} )
    return aligned_data

aligned_data = align_features(temporal_visual_features, temporal_logmel_features)
def init_dataloader(data, batch_size=1, shuffle=False):
    dataloaders = DataLoader(data,batch_size=1,shuffle=shuffle)
    return dataloaders

device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
model.to(device)

train_dataloader = init_dataloader(aligned_data, batch_size=1, shuffle=False)
dataloader_dict = {'train': train_dataloader,  'validate': train_dataloader }
num_batch_warm_up = len(train_dataloader)
for batch_idx, batch in tqdm(enumerate(dataloader_dict['train']), total=len(dataloader_dict['train'])):
    batch = {key: value.to(device) if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
    print(batch['video_name'], batch['category'], batch['label'].shape[1])
    with torch.no_grad():
        output = model(batch)
        print(output.shape)

3
3


100%|██████████| 2/2 [00:00<00:00, 473.64it/s]

['vidm031'] ['PTSD'] 3
torch.Size([3, 500, 1])
['Great_Day'] ['Non-PTSD'] 3
torch.Size([3, 500, 1])





In [None]:
from RJCMA.base.scheduler import  MyWarmupScheduler
from tqdm import tqdm
from torch import optim
import torch
import time
import copy
import os

class Trainer:
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def init_optimizer_and_scheduler(self, epoch=0):
        self.optimizer = optim.Adam(self.get_parameters(), lr=self.learning_rate, weight_decay=0.001)
        self.scheduler = MyWarmupScheduler(
            optimizer=self.optimizer, lr = self.learning_rate, min_lr=self.min_learning_rate,
            best=self.best_epoch_info['ccc'], mode="max", patience=self.patience,
            factor=self.factor, num_warmup_epoch=self.min_epoch, init_epoch=epoch)
            
    def Train(self, dataloader_dict, checkpoint_controller, parameter_controller):
        print("------\nStarting training on device:", self.device)
        self.time_fit_start = time.time()
        if self.best_epoch_info is None:
                self.best_epoch_info = {'model_weights': copy.deepcopy(self.model.state_dict()),'loss': 1e10,'ccc': -1e10}
        for epoch in range(self.start_epoch, self.max_epoch):
            if self.fit_finished:
                print("\nEarly Stop!\n")
                break
            
            if (epoch in self.milestone or (parameter_controller.get_current_lr() < self.min_learning_rate and epoch >= self.min_epoch and self.scheduler.relative_epoch > self.min_epoch)):
                parameter_controller.release_param(self.model.spatial, epoch)
                if parameter_controller.early_stop:
                    break
                self.model.load_state_dict(self.best_epoch_info['model_weights'])
            
            time_epoch_start = time.time()
            if self.verbose:
                print("There are {} layers to update.".format(len(self.optimizer.param_groups[0]['params'])))

            train_loss = self.train(dataloader_dict['train'], epoch =epoch)
            val_loss = self.validate(dataloader_dict['validate'], epoch=epoch)
            
            improvement = False
            if val_loss > self.best_epoch_info['loss']:
                torch.save(self.model.state_dict(), os.path.join(self.save_path, "model_state_dict"  + ".pth"))
                improvement = True
                self.best_epoch_info = {'model_weights': copy.deepcopy(self.model.state_dict()),'loss': val_loss,'epoch': epoch}

            if self.verbose:
                    print(
                        "\n Fold {:2} Epoch {:2} in {:.0f}s || Train loss={:.3f} | Val loss={:.3f} | LR={:.1e} | Release_count={} | best={} | "
                        "improvement={}-{}".format(
                            self.fold,
                            epoch + 1,
                            time.time() - time_epoch_start,
                            train_loss,
                            val_loss,
                            self.optimizer.param_groups[0]['lr'],
                            parameter_controller.release_count,
                            int(self.best_epoch_info['epoch']) + 1,
                            improvement,
                            self.early_stopping_counter))

            if self.early_stopping:
                if improvement:
                    self.early_stopping_counter = self.early_stopping
                else:
                    self.early_stopping_counter -= 1

                if self.early_stopping_counter <= 0 and epoch >= self.min_epoch:
                    self.fit_finished = True

            self.scheduler.step(metrics=val_loss, epoch=epoch)
            self.start_epoch = epoch + 1

            if self.load_best_at_each_epoch:
                    self.model.load_state_dict(self.best_epoch_info['model_weights'])
            checkpoint_controller.save_checkpoint(self, parameter_controller, self.save_path)
        self.fit_finished = True
        checkpoint_controller.save_checkpoint(self, parameter_controller, self.save_path)
        self.model.load_state_dict(self.best_epoch_info['model_weights'])

    def train(self, **kwargs):
        self.model.train()
        train_loss = self.fit(**kwargs)
        return train_loss

    def validate(self, **kwargs):
        with torch.no_grad():
            self.model.eval()
            val_loss = self.fit(**kwargs)
        return val_loss

    def fit(self, dataloader):
        running_loss = 0.0
        total_segments = 0  
        num_batch_warm_up = len(dataloader) * self.min_epoch
        for batch_idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
            self.scheduler.warmup_lr(self.learning_rate, batch_idx, num_batch_warm_up)
            segment_count = len(batch['segments'])
            total_segments += segment_count
            labels = batch.pop('label')
            self.optimizer.zero_grad()
            outputs = self.model(batch)
            loss = self.criterion(labels, outputs)
            running_loss += loss.mean().item() * segment_count 
            loss.backward()
            self.optimizer.step()

        epoch_loss = running_loss / total_segments 
        return epoch_loss


In [33]:
from RJCMA.base.loss_function import CCCLoss
from RJCMA.base.checkpointer import Checkpointer
import os

class Experiment:
    def __init__(self, args):
        super().__init__(args)
        self.args = args
        self.release_count = args.release_count
        self.gradual_release = args.gradual_release
        self.milestone = args.milestone
        self.backbone_mode = "ir"
        self.min_num_epochs = args.min_num_epochs
        self.num_epochs = args.num_epochs
        self.early_stopping = args.early_stopping
        self.load_best_at_each_epoch = args.load_best_at_each_epoch

        self.num_heads = args.num_heads
        self.modal_dim = args.modal_dim
        self.tcn_kernel_size = args.tcn_kernel_size

    def run(self):
        criterion = CCCLoss()
        for fold in iter(self.folds_to_run):

            save_path = os.path.join(self.save_path,
                                     self.experiment_name + "_" + self.model_name + "_" + self.stamp + "_fold" + str(
                                         fold) + "_" + self.emotion +  "_seed" + str(self.seed))
            os.makedirs(save_path, exist_ok=True)

            checkpoint_filename = os.path.join(save_path, "checkpoint.pkl")

            trainer_kwards = {'device': self.device, 'emotion': self.emotion, 'model_name': self.model_name,
                              'models': model, 'save_path': save_path, 'fold': fold,
                              'min_epoch': self.min_num_epochs, 'max_epoch': self.num_epochs,
                              'early_stopping': self.early_stopping, 'scheduler': self.scheduler,
                              'learning_rate': self.learning_rate, 'min_learning_rate': self.min_learning_rate,
                              'patience': self.patience, 'batch_size': self.batch_size,
                              'criterion': criterion, 'factor': self.factor, 'verbose': True,
                              'milestone': self.milestone, 'metrics': self.config['metrics'],
                              'load_best_at_each_epoch': self.load_best_at_each_epoch}

            trainer = Trainer(**trainer_kwards)

            parameter_controller = ResnetParamControl(trainer, gradual_release=self.gradual_release,
                                                      release_count=self.release_count,
                                                      backbone_mode=["spatial_visual_model", "spatial_audio_model"])

            checkpoint_controller = Checkpointer(checkpoint_filename, trainer, parameter_controller, resume=self.resume)

            if self.resume:
                trainer, parameter_controller = checkpoint_controller.load_checkpoint()

            if not trainer.fit_finished:
                trainer.train(dataloader_dict, parameter_controller=parameter_controller,
                            checkpoint_controller=checkpoint_controller)

    def get_selected_continuous_label_dim(self):
        if self.emotion == "PTSD":
            dim = [1]
        elif self.emotion == "Non-PTSD":
            dim = [0]
        else:
            raise ValueError("Unknown emotion!")
        return dim
