<a href="https://colab.research.google.com/github/saadyas/DMMS/blob/main/DMMS-modifications.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the ROUGE scorer (rouge_scorer.RougeScorer) used to evaluate text summaries.
!pip install rouge_score



In [None]:
import nltk
# Download NLTK's punkt_tab sentence-tokenizer data.
# NOTE(review): presumably required by rouge_score's split_summaries=True
# sentence splitting used for rougeLsum — confirm.
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [None]:
# Mount Google Drive (Colab only). NOTE(review): presumably the annotation and
# feature files under args['data_root'] live on Drive — confirm the path.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import random
import torch
import h5py
import numpy as np
import json
import math
from tqdm import tqdm

class MSMODataset(object):
    """Multimodal summarization dataset backed by pre-extracted features.

    For the split ``mode`` it reads:
      * ``{data_root}/{dataset}/annotation/{mode}.json`` -- per-sample
        annotations keyed by sample id ('text_label', 'article_sentence',
        'highlight' and, for Daily_Mail, 'video_label');
      * ``{data_root}/{dataset}/feature/video_resnet50_{mode}.npy`` and
        ``text_roberta_{mode}.npy`` -- pickled per-sample feature dicts;
      * for ``Daily_Mail`` only, ``video_summ_resnet50_{mode}.npy`` with the
        ground-truth keyframe features. Other datasets get a 1-element zero
        placeholder instead.

    All features are converted to float32 tensors once, at construction.

    :param mode: split name used in the file names (e.g. 'train', 'test').
    :param args: dict-like config; ``data_root`` and ``dataset`` are read.
    """
    def __init__(self, mode='train', args=None):
        data_root = args.get('data_root')
        self.dataset = args.get('dataset')
        is_daily_mail = self.dataset == 'Daily_Mail'
        feature_dir = '{}/{}/feature'.format(data_root, self.dataset)

        # Context manager closes the annotation file; json.load(open(...))
        # left the handle open until garbage collection.
        with open('{}/{}/annotation/{}.json'.format(data_root, self.dataset, mode)) as f:
            self.gt = json.load(f)
        self.id_list = list(self.gt.keys())

        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
        # feature files from a trusted source.
        self.video_dict = np.load('{}/video_resnet50_{}.npy'.format(feature_dir, mode), allow_pickle=True).item()
        self.text_dict = np.load('{}/text_roberta_{}.npy'.format(feature_dir, mode), allow_pickle=True).item()
        if is_daily_mail:
            self.video_summ_dict = np.load('{}/video_summ_resnet50_{}.npy'.format(feature_dir, mode), allow_pickle=True).item()
        else:
            self.video_summ_dict = {}

        for sample_id in tqdm(self.id_list):
            self.video_dict[sample_id] = torch.tensor(self.video_dict[sample_id]).to(torch.float32)
            self.text_dict[sample_id] = torch.tensor(self.text_dict[sample_id]).to(torch.float32)

            if is_daily_mail:
                self.video_summ_dict[sample_id] = torch.tensor(self.video_summ_dict[sample_id]).to(torch.float32)
            else:
                # Placeholder so __getitem__ can treat every dataset uniformly.
                self.video_summ_dict[sample_id] = torch.zeros(1).to(torch.float32)

    def __len__(self):
        return len(self.id_list)

    def __getitem__(self, index):
        """Return one sample as a 12-tuple.

        (video [T, 2048], video_summ, text [N, 768], mask_video,
        mask_video_summ, mask_text, video_label, text_label,
        article_sentence, highlight, video_to_text_mask, text_to_video_mask)

        The three masks are all-ones (no padding inside one sample); the two
        cross-modal masks are 1-element zero placeholders. For datasets other
        than Daily_Mail, ``video_label`` marks only frame 0 as a key frame.
        """
        sample_id = self.id_list[index]
        annotation = self.gt[sample_id]

        video = self.video_dict[sample_id]  # [T, 2048]
        video_summ = self.video_summ_dict[sample_id]
        text = self.text_dict[sample_id]  # [N, 768]
        num_frame = video.shape[0]
        num_keyframe = video_summ.shape[0]
        num_sentence = text.shape[0]

        if self.dataset == 'Daily_Mail':
            video_label = torch.tensor(annotation['video_label'], dtype=torch.long)
            # One labelled frame per ground-truth keyframe feature.
            assert torch.sum(video_label) == num_keyframe
        else:
            video_label = torch.zeros(num_frame).to(torch.long)
            video_label[0] = 1
        text_label = torch.tensor(annotation['text_label'], dtype=torch.long)

        article_sentence = annotation['article_sentence']
        highlight = annotation['highlight']

        mask_video = torch.ones(num_frame, dtype=torch.long)
        mask_video_summ = torch.ones(num_keyframe, dtype=torch.long)
        mask_text = torch.ones(num_sentence, dtype=torch.long)

        video_to_text_mask = torch.zeros(1)
        text_to_video_mask = torch.zeros(1)
        return video, video_summ, text, mask_video, mask_video_summ, mask_text, video_label, text_label, article_sentence, highlight, video_to_text_mask, text_to_video_mask

def worker_init_fn(worker_id):
    """Seed Python's ``random`` and NumPy's global RNG inside a DataLoader worker.

    The seed is ``torch.initial_seed()`` reduced to 32 bits (NumPy's seed
    range). Inside a DataLoader worker this is presumably the worker's own
    torch seed, so workers differ but stay reproducible. ``worker_id`` is
    not used.
    """
    seed = torch.initial_seed() % (1 << 32)
    np.random.seed(seed)
    random.seed(seed)

def my_collate_fn(batch):
    """Transpose a list of samples into one list per field.

    Nothing is stacked or padded, so variable-length tensors and raw Python
    objects (e.g. sentence lists) pass through unchanged. The number of
    fields is taken from ``batch[0]``.
    """
    num_fields = len(batch[0])
    return [[sample[field] for sample in batch] for field in range(num_fields)]


In [None]:
import os
import logging
import random
import yaml
from pathlib import Path
from os import PathLike
from typing import Any, List, Dict

import numpy as np
import torch

def set_random_seed(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``
    and switch cuDNN to deterministic, non-autotuned kernels."""
    for seed_fn in (random.seed,
                    np.random.seed,
                    torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Give up cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def init_logger(log_dir: str, log_file: str) -> None:
    """Configure the root logger for INFO output to the console and to a file.

    Creates ``log_dir`` if needed and attaches a ``FileHandler`` for
    ``log_dir/log_file`` to the root logger. Calling this again for the same
    file (e.g. re-running the notebook cell) does not attach a second handler,
    which would otherwise write every log line to the file twice.

    :param log_dir: directory for the log file (created if missing).
    :param log_file: file name inside ``log_dir``.
    """
    logger = logging.getLogger()
    format_str = r'[%(asctime)s] %(message)s'
    # basicConfig is a no-op once the root logger already has handlers.
    logging.basicConfig(
        level=logging.INFO,
        datefmt=r'%Y/%m/%d %H:%M:%S',
        format=format_str
    )
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    # FileHandler stores an absolute path in ``baseFilename``; compare likewise.
    log_path = os.path.abspath(str(log_dir / log_file))
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if already_attached:
        return
    fh = logging.FileHandler(log_path)
    fh.setFormatter(logging.Formatter(format_str))
    logger.addHandler(fh)

def load_yaml(path: PathLike) -> Any:
    """Read the YAML file at ``path`` with ``yaml.safe_load`` and return the parsed object."""
    with open(path) as stream:
        return yaml.safe_load(stream)

def dump_yaml(obj: Any, path: PathLike) -> None:
    """Write ``obj`` as YAML to ``path``, replacing any existing file content."""
    with open(path, 'w') as stream:
        yaml.dump(obj, stream)

class AverageMeter(object):
    """Running mean of named scalar values.

    Construct with the names to track, feed values with
    ``update(name=value)`` and read the current mean as an attribute
    (``meter.name``). A tracked name that was never updated reads as 0.0.

    Reading an unknown attribute raises ``AttributeError`` (not
    ``AssertionError``), so ``hasattr``, ``getattr(obj, name, default)`` and
    ``copy.deepcopy`` behave normally.
    """
    def __init__(self, *keys: str):
        self.totals = {key: 0.0 for key in keys}
        self.counts = {key: 0 for key in keys}

    def update(self, **kwargs: float) -> None:
        """Add one observation per given name; every call counts as one sample."""
        for key, value in kwargs.items():
            self._check_attr(key)
            self.totals[key] += value
            self.counts[key] += 1

    def __getattr__(self, attr: str) -> float:
        # Only reached when normal attribute lookup fails. Read the dicts via
        # __dict__ so an instance whose __init__ has not run (e.g. inside
        # copy.deepcopy) cannot recurse back into __getattr__ forever.
        totals = self.__dict__.get('totals')
        counts = self.__dict__.get('counts')
        if totals is None or counts is None or attr not in totals or attr not in counts:
            raise AttributeError(
                '{!r} object has no attribute {!r}'.format(type(self).__name__, attr))
        count = counts[attr]
        return totals[attr] / count if count else 0.0

    def _check_attr(self, attr: str) -> None:
        assert attr in self.totals and attr in self.counts



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImprovedContrastiveLoss(nn.Module):
    """InfoNCE-style loss: each query vs. one positive key and a set of negatives.

    Logits are cosine similarities divided by the fixed ``temperature``; the
    positive sits in column 0 of the logits and cross-entropy targets it.

    ``logit_scale`` is registered as a learnable parameter (initialised to
    log(1 / temperature)) but ``forward`` does not read it.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        init_scale = torch.log(torch.tensor(1.0 / self.temperature))
        self.logit_scale = nn.Parameter(torch.ones([]) * init_scale)

    def forward(self, q, k, negatives):
        """
        q: query embeddings, (B, D) or (B, 1, D)
        k: positive key embeddings, (B, D) or (B, 1, D)
        negatives: (B, N, D), or a 2-D tensor that is given a singleton axis 1
            and then reshaped to (B, -1, D)
        Returns the mean cross-entropy over the B queries.
        """
        # Drop a singleton axis 1 so queries and keys are (B, D).
        query = q.squeeze(1)
        key = k.squeeze(1)
        batch_size, feat_dim = query.shape

        if negatives.dim() == 2:
            negatives = negatives.unsqueeze(1)
        elif negatives.dim() == 3 and negatives.shape[1] != 1:
            negatives = negatives.view(batch_size, -1, feat_dim)

        # eps keeps all-zero vectors from producing NaNs.
        eps = 1e-6
        query = F.normalize(query, dim=-1, eps=eps)
        key = F.normalize(key, dim=-1, eps=eps)
        negatives = F.normalize(negatives, dim=-1, eps=eps)

        query_rows = query.unsqueeze(1)                                               # (B, 1, D)
        key_cols = key.view(batch_size, feat_dim, 1)                                  # (B, D, 1)
        neg_cols = negatives.view(batch_size, -1, feat_dim).permute(0, 2, 1)          # (B, D, N)

        positive_logits = torch.bmm(query_rows, key_cols).squeeze(1) / self.temperature   # (B, 1)
        negative_logits = torch.bmm(query_rows, neg_cols).squeeze(1) / self.temperature   # (B, N)

        logits = torch.cat((positive_logits, negative_logits), dim=1)                 # (B, 1 + N)
        targets = torch.zeros(logits.shape[0], dtype=torch.long).to(query.device)
        return F.cross_entropy(logits, targets)


In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import torch.nn.utils.rnn as rnn_utils

def calc_cls_loss(pred: torch.Tensor,
                  target: torch.Tensor,
                  mask: torch.Tensor = None,
                  ) -> torch.Tensor:
    """Binary focal classification loss averaged over valid positions.

    Positive and negative positions both contribute to the loss.

    :param pred: Raw logits (sigmoid is applied here). Sized [B, N].
    :param target: Integer class target, 1 = positive, 0 = negative. Sized [B, N].
    :param mask: Non-zero marks valid (non-padded) positions. Sized [B, N].
        If None, every position is treated as valid (previously None crashed).
    :return: Scalar loss: mean of the per-class focal terms over valid positions.
    """
    prob = torch.sigmoid(pred)
    # Two-class layout [..., (negative, positive)] as expected by focal_loss.
    prob = torch.stack([1 - prob, prob], dim=-1)
    if mask is None:
        mask = torch.ones(target.shape, dtype=torch.bool, device=target.device)
    else:
        mask = mask.to(torch.bool)
    loss = focal_loss(prob, target, reduction='none')
    loss = loss[mask, :]
    return torch.mean(loss)


def focal_loss(pred: torch.Tensor,
               target: torch.Tensor,
               alpha: float = 0.25,
               gamma: float = 2,
               reduction: str = 'sum'
               ) -> torch.Tensor:
    """Compute focal loss from per-class probabilities.
        FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

    Every class channel is scored against a one-hot target: the target class
    is weighted by ``alpha`` and the remaining channels by ``1 - alpha``.

    :param pred: Predicted per-class probabilities, class axis last. Sized
        [..., C] (calc_cls_loss passes [B, N, 2]). Any number of leading axes
        is accepted; previously exactly three axes were required.
    :param target: Integer class indices, sized like ``pred`` minus its last axis.
    :param alpha: Weight of the target-class term (``1 - alpha`` for the others).
    :param gamma: Focusing exponent; larger values down-weight confident predictions.
    :param reduction: Aggregation type. Choose from (sum, mean, none).
    :return: Scalar for 'sum'/'mean'; element-wise loss shaped like ``pred`` for 'none'.
    :raises ValueError: on an unknown ``reduction``.
    """
    num_classes = pred.shape[-1]
    t = F.one_hot(target, num_classes)

    p_t = pred * t + (1 - pred) * (1 - t)
    alpha_t = alpha * t + (1 - alpha) * (1 - t)
    # The clamp keeps log() finite when a probability underflows to 0.
    fl = -alpha_t * (1 - p_t).pow(gamma) * p_t.clamp(min=1e-7).log()

    # TODO: consider averaging 'sum' across the batch axis instead.
    if reduction == 'sum':
        fl = fl.sum()
    elif reduction == 'mean':
        fl = fl.mean()
    elif reduction == 'none':
        pass
    else:
        raise ValueError(f'Invalid reduction mode {reduction}')

    return fl


def iou_offset(offset_a: torch.Tensor,
               offset_b: torch.Tensor,
               eps: float = 1e-8
               ) -> torch.Tensor:
    """IoU of 1-D segments given as (left, right) extents around the same position.

    A segment's length is ``left + right``; two segments sharing the anchor
    overlap by ``min(left_a, left_b) + min(right_a, right_b)``, floored at 0.

    :param offset_a: Offsets of N positions. Sized [N, 2].
    :param offset_b: Offsets of N positions. Sized [N, 2].
    :param eps: Value substituted for non-positive unions to avoid division by zero.
    :return: IoU values of N positions. Sized [N].
    """
    overlap = torch.min(offset_a[:, 0], offset_b[:, 0]) + torch.min(offset_a[:, 1], offset_b[:, 1])
    overlap = overlap.clamp(min=0)

    span_a = offset_a[:, 0] + offset_a[:, 1]
    span_b = offset_b[:, 0] + offset_b[:, 1]
    union = span_a + span_b - overlap
    union = union.masked_fill(union <= 0, eps)

    return overlap / union


def calc_loc_loss(pred_loc_batch: torch.Tensor,
                  test_loc_batch: torch.Tensor,
                  cls_label: torch.Tensor,
                  kind: str = 'soft-iou',
                  eps: float = 1e-8
                  ) -> torch.Tensor:
    """Regression loss on positive positions only, averaged over the batch.

    For each sample the loss is computed on the positions where ``cls_label``
    is non-zero and the per-sample losses are averaged.

    :param pred_loc_batch: Predicted offsets. Sized [B, N, 2].
    :param test_loc_batch: Ground truth offsets. Sized [B, N, 2].
    :param cls_label: Non-zero marks positive positions. Sized [B, N].
    :param kind: 'soft-iou' (-log(IoU + eps)) or 'smooth-l1'.
    :param eps: Small value added to the IoU before the log.
    :return: Scalar loss value.
    :raises ValueError: for any other ``kind``.
    """
    positive = cls_label.to(torch.bool)
    num_samples = positive.shape[0]

    total = 0
    for sample_idx in range(num_samples):
        selected = positive[sample_idx]
        pred_pos = pred_loc_batch[sample_idx, selected]
        gt_pos = test_loc_batch[sample_idx, selected]

        if kind == 'smooth-l1':
            sample_loss = F.smooth_l1_loss(pred_pos, gt_pos)
        elif kind == 'soft-iou':
            sample_loss = -torch.log(iou_offset(pred_pos, gt_pos) + eps).mean()
        else:
            raise ValueError(f'Invalid loss type {kind}')
        total += sample_loss

    return total / num_samples


def calc_ctr_loss(pred_batch, test_batch, pos_mask):
    """Batch mean of per-sample binary cross-entropy on masked positions.

    :param pred_batch: predicted probabilities in [0, 1], [B, T].
    :param test_batch: binary targets, [B, T].
    :param pos_mask: non-zero marks the positions that contribute, [B, T].
    :return: scalar loss (per-sample BCE means, averaged over the batch).
    """
    keep = pos_mask.to(torch.bool)
    per_sample = [
        F.binary_cross_entropy(pred_batch[b, keep[b]], test_batch[b, keep[b]])
        for b in range(keep.shape[0])
    ]
    return sum(per_sample) / keep.shape[0]

@torch.no_grad()
def calc_text_rouge(article_sentence_list, highlight_list, selected_sentence_index_list, dataset=None, rouge=None):
    """Mean ROUGE-1 / ROUGE-2 / ROUGE-Lsum F-measure of extractive summaries.

    For each sample, the selected article sentences are put back in article
    order, joined with spaces and scored against the reference highlight.

    :param article_sentence_list: per sample, the list of article sentences.
    :param highlight_list: per sample, the reference summary as a list/tuple
        of sentences (joined with spaces) or a single string.
    :param selected_sentence_index_list: per sample, indices of selected sentences.
    :param dataset: unused; kept for call-site compatibility.
    :param rouge: scorer with ``score(target, prediction)`` returning a mapping
        with 'rouge1', 'rouge2', 'rougeLsum' entries whose index 2 is the
        F-measure (e.g. ``rouge_scorer.RougeScorer``).
    :return: (R1, R2, RLsum) F-measures averaged over the batch.
    :raises TypeError: if a highlight is neither a list/tuple nor a string
        (previously this left ``reference_sentence`` unbound or stale).
    """
    batch_size = len(selected_sentence_index_list)

    R1_sum = 0
    R2_sum = 0
    RL_sum = 0
    for i in range(batch_size):
        # Keep the selected sentences in their original article order.
        selected_sentence_list = [article_sentence_list[i][j] for j in sorted(selected_sentence_index_list[i])]
        evaluated_sentence = ' '.join(selected_sentence_list)

        highlight = highlight_list[i]
        if isinstance(highlight, (list, tuple)):
            reference_sentence = ' '.join(highlight)
        elif isinstance(highlight, str):
            reference_sentence = highlight
        else:
            raise TypeError('highlight must be a list of sentences or a string, got {}'.format(
                type(highlight).__name__))

        # RougeScorer.score(target, prediction): the reference goes first.
        # Swapping the arguments exchanges precision/recall and can change
        # the summary-level rougeLsum score.
        scores = rouge.score(reference_sentence, evaluated_sentence)
        R1_sum += scores['rouge1'][2]
        R2_sum += scores['rouge2'][2]
        RL_sum += scores['rougeLsum'][2]

    R1_mean = R1_sum / batch_size
    R2_mean = R2_sum / batch_size
    RL_mean = RL_sum / batch_size
    return R1_mean, R2_mean, RL_mean

@torch.no_grad()
def calc_video_cos(video, gt_summ, keyframe_index_list, mask_video_summ=None, dataset=None):
    """Mean greedy-matched cosine similarity between GT and predicted keyframes.

    For each sample a similarity matrix (rows = valid GT keyframes, columns =
    predicted keyframes) is built:

    * ``Daily_Mail``: cosine similarity with the predicted keyframe features.
    * ``BLiSS``: cosine similarity with every frame, min-max rescaled to
      [0, 1] over the whole matrix, then restricted to the predicted keyframes.

    Rows and columns are then paired one-to-one greedily in descending
    similarity until ``min(#GT, #predicted)`` pairs are chosen, and the
    chosen similarities are averaged.

    :param video: per-sample frame features, each [T, C].
    :param gt_summ: GT keyframe features, [B, G, C] (padded).
    :param keyframe_index_list: per-sample predicted keyframe indices into ``video[i]``.
    :param mask_video_summ: per-sample selector of valid GT keyframes.
        NOTE(review): presumably a boolean mask; an integer tensor would be
        treated as indices by the indexing below — confirm at the call site.
    :param dataset: 'Daily_Mail' or 'BLiSS'.
    :return: batch mean of the per-sample matched similarity.
    :raises ValueError: for any other ``dataset`` (previously an UnboundLocalError).
    """
    if dataset not in ('Daily_Mail', 'BLiSS'):
        raise ValueError(f'Unsupported dataset {dataset!r}: expected Daily_Mail or BLiSS')

    batch_size = len(keyframe_index_list)
    gt_summ = F.normalize(gt_summ, dim=-1)

    cos_sim_sum = 0
    for i in range(batch_size):
        gt_keyframes = gt_summ[i, mask_video_summ[i]]
        if dataset == 'Daily_Mail':
            pred_summ = F.normalize(video[i][keyframe_index_list[i]], dim=1)
            sim_mat = gt_keyframes @ pred_summ.permute(1, 0)
        else:  # BLiSS
            pred_summ = F.normalize(video[i], dim=1)
            sim_mat = gt_keyframes @ pred_summ.permute(1, 0)
            sim_mat = sim_mat - torch.min(sim_mat)
            sim_mat = sim_mat / torch.max(sim_mat).clamp(min=1e-6)
            sim_mat = sim_mat[:, keyframe_index_list[i]]
        sim_mat = sim_mat.detach().cpu().numpy()

        # Greedy one-to-one matching over the largest similarities. The match
        # matrix is sized like sim_mat (the old square K x K matrix broke the
        # element-wise product whenever #GT != #predicted keyframes).
        num_rows, num_cols = sim_mat.shape
        num_match = min(num_rows, num_cols)
        match_mat = np.zeros((num_rows, num_cols), dtype=int)
        row_used = np.zeros(num_rows, dtype=bool)
        col_used = np.zeros(num_cols, dtype=bool)
        order_rows, order_cols = np.unravel_index(np.argsort(-sim_mat.ravel()), sim_mat.shape)
        matched = 0
        for m, n in zip(order_rows, order_cols):
            if matched >= num_match:
                break
            if row_used[m] or col_used[n]:
                continue
            match_mat[m, n] = 1
            row_used[m] = True
            col_used[n] = True
            matched += 1

        cos_sim = np.sum(sim_mat * match_mat) / np.sum(match_mat)
        cos_sim_sum += cos_sim

    cos_sim_mean = cos_sim_sum / batch_size
    return cos_sim_mean


class NCE(nn.Module):
    """InfoNCE loss for a query against one positive key and T negatives.

    Logits are cosine similarities multiplied by the learnable
    ``logit_scale`` (initialised to log(1 / 0.07)); the positive is column 0.
    """
    def __init__(self):
        super(NCE, self).__init__()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, q, k, neg, device=None):
        """
        :param q: query embedding, [1, C].
        :param k: positive key embedding, [1, C].
        :param neg: negative embeddings, [T, C].
        :param device: device for the target labels. Defaults to the logits'
            device; the former hard-coded 'cuda:0' default broke CPU runs.
        :return: scalar cross-entropy loss.
        """
        q = F.normalize(q, dim=1) #[1, C]
        k = F.normalize(k, dim=1) #[1, C]
        neg = F.normalize(neg, dim=1) #[T, C]
        l_pos = q @ k.T #[1, 1]
        l_neg = q @ neg.T #[1, T]
        # NOTE(review): logit_scale is used as-is; CLIP-style code multiplies
        # by logit_scale.exp() given this log-space init. Confirm intent.
        logits = torch.cat([l_pos, l_neg], dim=1) * self.logit_scale #[1, 1 + T]

        if device is None:
            device = logits.device
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)
        loss = F.cross_entropy(logits, labels)
        return loss


class Dual_Contrastive_Loss(nn.Module):
    """Inter- and intra-sample contrastive losses between video and text.

    ``contrastive_pairs`` holds 'cls_video' / 'cls_text' (one CLS embedding
    per sample, leading batch axis) and the per-sample lists
    'key_video_list', 'nonkey_video_list', 'key_text_list' and
    'nonkey_text_list'.

    Negatives are used per sample and are not zero-padded. After
    normalisation a padded zero row becomes a zero vector whose logit is 0,
    so padding would add one spurious negative per padded slot and bias
    samples that have fewer non-key embeddings.
    """
    def __init__(self, args=None):
        super().__init__()
        self.video_contrast = ImprovedContrastiveLoss()
        self.text_contrast = ImprovedContrastiveLoss()

    def forward(self, contrastive_pairs):
        """
        :return: (inter_contrastive_loss, intra_contrastive_loss), each a
            batch average of the mean of its video-side and text-side terms.
        """
        if len(contrastive_pairs) == 0:
            # Follow the module's device instead of hard-coding CUDA.
            device = next(self.parameters()).device
            return torch.zeros(1, device=device), torch.zeros(1, device=device)

        cls_video = contrastive_pairs['cls_video']
        cls_text = contrastive_pairs['cls_text']
        key_video_list = contrastive_pairs['key_video_list']
        nonkey_video_list = contrastive_pairs['nonkey_video_list']
        key_text_list = contrastive_pairs['key_text_list']
        nonkey_text_list = contrastive_pairs['nonkey_text_list']

        B = cls_video.shape[0]

        inter_contrastive_loss = 0
        intra_contrastive_loss = 0
        for i in range(B):
            # Slicing keeps a leading batch axis of size 1.
            cls_video_i = cls_video[i:i + 1]
            cls_text_i = cls_text[i:i + 1]
            nonkey_video = nonkey_video_list[i]  # unpadded negatives of sample i
            nonkey_text = nonkey_text_list[i]

            # Inter-sample: each modality's CLS token should match the other
            # modality's CLS token rather than its own non-key embeddings.
            inter_video_loss = self.video_contrast(cls_video_i, cls_text_i, nonkey_video)
            inter_text_loss = self.text_contrast(cls_text_i, cls_video_i, nonkey_text)
            inter_contrastive_loss += (inter_video_loss + inter_text_loss) / 2

            # Intra-sample: mean key embedding of one modality vs. the other's,
            # contrasted with the non-key embeddings.
            key_video = torch.mean(key_video_list[i], dim=0, keepdim=True)
            key_text = torch.mean(key_text_list[i], dim=0, keepdim=True)
            intra_video_loss = self.video_contrast(key_video, key_text, nonkey_video)
            intra_text_loss = self.text_contrast(key_text, key_video, nonkey_text)
            intra_contrastive_loss += (intra_video_loss + intra_text_loss) / 2

        inter_contrastive_loss /= B
        intra_contrastive_loss /= B
        return inter_contrastive_loss, intra_contrastive_loss


In [None]:
import torch
from torch import nn
from torch.nn import functional as F

import numpy as np
from scipy import ndimage

class MoELayer(nn.Module):
    """Dense (soft) mixture of FFN experts with an auxiliary load-balancing loss.

    Every token goes through every expert, and the outputs are mixed with
    softmax gate weights. The auxiliary loss is ``balancing_coef`` times the
    squared coefficient of variation (var / mean^2) of the total gate weight
    each expert receives over tokens whose features are not all zero.
    """
    def __init__(self, num_experts, hidden_size, dropout=0.1, ratio=4):
        super().__init__()
        self.experts = nn.ModuleList(
            [FFN(hidden_size, p=dropout, ratio=ratio) for _ in range(num_experts)]
        )
        self.gate = nn.Linear(hidden_size, num_experts)
        self.num_experts = num_experts
        self.balancing_coef = 0.01  # weight of the load-balancing loss

    def forward(self, x):
        """
        :param x: token features, [B, N, C].
        :return: (mixed output [B, N, C], scalar load-balancing loss)
        """
        gate_weights = self.gate(x).softmax(dim=-1)                              # [B, N, E]

        per_expert = [expert(x) for expert in self.experts]
        stacked = torch.stack(per_expert, dim=2)                                 # [B, N, E, C]
        mixed = torch.einsum('bnec,bne->bnc', stacked, gate_weights)             # [B, N, C]

        # All-zero feature rows are treated as padding and ignored.
        is_token = x.abs().sum(dim=-1) > 0
        expert_load = gate_weights[is_token].sum(dim=0)                          # [E]
        spread = expert_load.var() / (expert_load.mean() ** 2 + 1e-10)
        return mixed, self.balancing_coef * spread

class MultiHeadAttention(nn.Module):
    """Batch-first multi-head scaled dot-product attention.

    Inputs are projected to ``h_dims`` and split into ``heads`` heads of
    ``h_dims // heads`` channels (``h_dims`` should be divisible by
    ``heads``). The concatenated heads are projected to ``o_dims``. Dropout
    is applied to the attention weights and to the output.
    """
    def __init__(self,
                 dims,
                 k_dims=None,
                 v_dims=None,
                 h_dims=None,
                 o_dims=None,
                 heads=8,
                 p=0.1,
                 bias=True):
        """
        :param dims: query feature size (also the default for every other size).
        :param k_dims: key feature size; defaults to ``dims``.
        :param v_dims: value feature size; defaults to ``dims``.
        :param h_dims: total projected size across all heads; defaults to ``dims``.
        :param o_dims: output feature size; defaults to ``dims``.
        :param heads: number of attention heads.
        :param p: dropout probability for attention weights and output.
        :param bias: whether the linear projections use a bias.
        """
        super(MultiHeadAttention, self).__init__()

        self._q_dims = dims
        self._k_dims = k_dims or dims
        self._v_dims = v_dims or dims
        self._h_dims = h_dims or dims
        self._o_dims = o_dims or dims
        self._heads = heads
        self._p = p
        self._bias = bias
        self._head_dims = self._h_dims // heads

        self.q = nn.Linear(self._q_dims, self._h_dims, bias=bias)
        self.k = nn.Linear(self._k_dims, self._h_dims, bias=bias)
        self.v = nn.Linear(self._v_dims, self._h_dims, bias=bias)
        self.m = nn.Linear(self._h_dims, self._o_dims, bias=bias)

        self.drop1 = nn.Dropout(p)
        self.drop2 = nn.Dropout(p)

        self.reset_parameters()

    def __repr__(self):
        return ('{}(q_dims={}, k_dims={}, v_dims={}, h_dims={}, o_dims={}, '
                'heads={}, p={}, bias={})'.format(self.__class__.__name__,
                                                  self._q_dims, self._k_dims,
                                                  self._v_dims, self._h_dims,
                                                  self._o_dims, self._heads,
                                                  self._p, self._bias))

    def reset_parameters(self):
        """Xavier-normal init for the four projections' weights, zero biases."""
        for m in (self.q, self.k, self.v, self.m):
            nn.init.xavier_normal_(m.weight, gain=1.0)
            if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, q, k=None, v=None, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        :param q: queries, [B, Lq, q_dims] (batch first).
        :param k: keys, [B, Lk, k_dims]; defaults to ``q``.
        :param v: values, [B, Lk, v_dims]; defaults to ``k`` if given, else ``q``.
        :param mask: optional [B, Lq, Lk]; entries > 0 may be attended, all
            others get -inf before the softmax. A query row with no allowed key
            becomes NaN after the softmax.
        :return: [B, Lq, o_dims].
        """
        # Fall back to self-attention when keys/values are not given.
        v = v if torch.is_tensor(v) else k if torch.is_tensor(k) else q
        k = k if torch.is_tensor(k) else q

        # Project, then move the sequence axis first: [L, B, h_dims].
        q = self.q(q).transpose(0, 1).contiguous()
        k = self.k(k).transpose(0, 1).contiguous()
        v = self.v(v).transpose(0, 1).contiguous()

        # Fold heads into the batch axis (batch-major, head-minor).
        b = q.size(1) * self._heads

        # [L, B*heads, head_dims] -> [B*heads, L, head_dims]
        q = q.view(-1, b, self._head_dims).transpose(0, 1)
        k = k.view(-1, b, self._head_dims).transpose(0, 1)
        v = v.view(-1, b, self._head_dims).transpose(0, 1)

        # Scaled dot-product scores: [B*heads, Lq, Lk]
        att = torch.bmm(q, k.transpose(1, 2)) / self._head_dims**0.5

        if mask is not None:
            # Additive mask; repeat_interleave matches the batch-major/head-minor fold above.
            mask = torch.where(mask > 0, .0, float('-inf'))
            mask = mask.repeat_interleave(self._heads, dim=0)
            att += mask

        att = att.softmax(-1)

        if self.drop1 is not None:
            att = self.drop1(att)

        # Weighted values, then un-fold heads back to [B, Lq, h_dims].
        m = torch.bmm(att, v).transpose(0, 1).contiguous()
        m = m.view(m.size(0), -1, self._h_dims).transpose(0, 1)
        m = self.m(m)

        if self.drop2 is not None:
            m = self.drop2(m)

        return m

class FFN(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The hidden layer is ``ratio`` times wider than ``num_input``; input and
    output sizes are both ``num_input``.
    """
    def __init__(self, num_input, p=0.1, ratio=4):
        super().__init__()
        hidden = num_input * ratio
        self.fc1 = nn.Linear(num_input, hidden)
        self.act = nn.GELU()
        self.drop1 = nn.Dropout(p)
        self.fc2 = nn.Linear(hidden, num_input)
        self.drop2 = nn.Dropout(p)

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))

class MultiWayTransformer(nn.Module):
    """One fusion layer over a concatenated [video; text] token sequence.

    A shared pre-norm self-attention runs over the whole fused sequence.
    The sequence is then split back into its video and text parts, and each
    part gets its own pre-norm mixture-of-experts feed-forward branch. Both
    sub-layers are residual.
    """
    def __init__(self, num_hidden, dropout_attn=0.1, num_experts=4):
        super().__init__()
        # Shared cross-modal attention over the fused sequence.
        self.norm1_fused = nn.LayerNorm(num_hidden)
        self.attn_fusion = MultiHeadAttention(num_hidden, p=dropout_attn)

        # Modality-specific mixture-of-experts feed-forward branches.
        self.norm2_video = nn.LayerNorm(num_hidden)
        self.moe_video = MoELayer(num_experts, num_hidden, dropout_attn)
        self.norm2_text = nn.LayerNorm(num_hidden)
        self.moe_text = MoELayer(num_experts, num_hidden, dropout_attn)

        self.num_experts = num_experts

    def forward(self, fused, mask_fused, N_video, N_text):
        """
        :param fused: concatenated tokens, [B, N_video + N_text, C].
        :param mask_fused: attention mask forwarded to MultiHeadAttention.
        :param N_video: number of leading video tokens.
        :param N_text: number of trailing text tokens.
        :return: (video tokens [B, N_video, C], text tokens [B, N_text, C],
            summed MoE auxiliary loss of both branches)
        """
        normed = self.norm1_fused(fused)
        attended = fused + self.attn_fusion(normed, normed, normed, mask=mask_fused)
        video_tokens, text_tokens = torch.split(attended, [N_video, N_text], dim=1)

        video_update, video_aux = self.moe_video(self.norm2_video(video_tokens))
        text_update, text_aux = self.moe_text(self.norm2_text(text_tokens))

        return video_tokens + video_update, text_tokens + text_update, video_aux + text_aux


# For Daily_Mail/CNN datasets
class Model_MSMO(nn.Module):
    def __init__(self, args=None):
        super().__init__()
        num_input_video = args.get('num_input_video')
        num_input_text = args.get('num_input_text')
        num_hidden = args.get('num_hidden')

        self.ratio = args.get('ratio')

        self.proj_fc_video = nn.Sequential(
                                nn.Linear(num_input_video, num_hidden, bias=True),
                                nn.Dropout(args.get('dropout_video')),
                            )
        self.proj_fc_text = nn.Sequential(
                                nn.Linear(num_input_text, num_hidden, bias=True),
                                nn.Dropout(args.get('dropout_text')),
                            )

        self.pos_embed_video = nn.Parameter(torch.zeros(1, 5000, num_hidden))
        self.pos_embed_text = nn.Parameter(torch.zeros(1, 5000, num_hidden))
        self.type_video = nn.Parameter(torch.zeros(1, 1, num_hidden))
        self.type_text = nn.Parameter(torch.zeros(1, 1, num_hidden))
        self.cls_token_video = nn.Parameter(torch.zeros(1, 1, num_hidden))
        self.cls_token_text = nn.Parameter(torch.zeros(1, 1, num_hidden))

        self.cls_mask_video = torch.ones([1, 1])
        self.cls_mask_text = torch.ones([1, 1])

        self.multiway_list = nn.ModuleList([MultiWayTransformer(num_hidden, dropout_attn=args.get('dropout_attn'))] * args.get('num_layers'))

        self.norm_video = nn.LayerNorm(num_hidden)
        self.norm_text = nn.LayerNorm(num_hidden)

        self.fc_video = nn.Sequential(
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(True),
            nn.Dropout(args.get('dropout_fc')),
            nn.Linear(num_hidden, 1),
        )
        self.fc_text = nn.Sequential(
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(True),
            nn.Dropout(args.get('dropout_fc')),
            nn.Linear(num_hidden, 1),
        )

        self.num_layers = args.get('num_layers')

        nn.init.trunc_normal_(self.pos_embed_video, std=.02)
        nn.init.trunc_normal_(self.pos_embed_text, std=.02)
        nn.init.trunc_normal_(self.type_video, std=.02)
        nn.init.trunc_normal_(self.type_text, std=.02)
        nn.init.trunc_normal_(self.cls_token_video, std=.02)
        nn.init.trunc_normal_(self.cls_token_text, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def select_contrastive_embedding(self, score, embedding, mask, label):
        B = score.shape[0]

        key_embedding_list = []
        nonkey_embedding_list = []
        for i in range(B):
            length = torch.sum(mask[i].to(torch.long))
            key_embedding_num = max(1, length // self.ratio)
            nonkey_embedding_num = max(1, length // self.ratio)

            key_embedding_index = label[i].to(torch.bool)
            key_embedding = embedding[i, key_embedding_index]

            key_embedding_index_expand = ndimage.binary_dilation(label[i].cpu().detach().numpy(), iterations=4).astype(np.int32)
            key_embedding_index_expand = torch.from_numpy(key_embedding_index_expand)

            score_i = score[i, :length]
            score_i = F.softmax(score_i, dim=-1)

            _, idx_DESC = score_i.sort(descending=True)

            non_key_embedding_index = []
            for j in range(idx_DESC.shape[0]):
                if key_embedding_index_expand[idx_DESC[j]] == 0:
                    non_key_embedding_index.append(idx_DESC[j].item())
                if len(non_key_embedding_index) >= nonkey_embedding_num:
                    break

            if len(non_key_embedding_index) == 0:
                non_key_embedding_index.append(idx_DESC[-1])

            nonkey_embedding = embedding[i, non_key_embedding_index]

            key_embedding_list.append(key_embedding)
            nonkey_embedding_list.append(nonkey_embedding)
        return key_embedding_list, nonkey_embedding_list


    def forward(self, **kwargs):
        video = kwargs['video']
        text = kwargs['text']
        mask_video = kwargs['mask_video']
        mask_text = kwargs['mask_text']
        video_label = kwargs['video_label']
        text_label = kwargs['text_label']

        B = video.shape[0]
        video = self.proj_fc_video(video)
        text = self.proj_fc_text(text)

        # prepend the [CLSV] and [CLST] tokens to the video and text feature sequences
        video = torch.cat([self.cls_token_video.expand(B, -1, -1), video], dim=1)
        text = torch.cat([self.cls_token_text.expand(B, -1, -1), text], dim=1)
        mask_video = torch.cat([self.cls_mask_video.expand(B, -1).to(mask_video), mask_video], dim=1)
        mask_text = torch.cat([self.cls_mask_text.expand(B, -1).to(mask_text), mask_text], dim=1)

        # add positional embedding
        B, N_video, C = video.shape
        B, N_text, C = text.shape
        video = video + self.pos_embed_video[:, :N_video, :] + self.type_video
        text = text + self.pos_embed_text[:, :N_text, :] + self.type_text
        # Feature Concatenation (Early Fusion) :
        fused = torch.cat([video, text], dim=1) # Text and video Are fused (concatenated in this cell) !!!!!
        mask_fused = torch.cat([mask_video, mask_text], dim=1) #[B, N_video+N_text]
        mask_fused = mask_fused.unsqueeze(1).expand(-1, N_video+N_text, -1) #[B, N_video+N_text, N_video+N_text]
        # multiway transformer layers
        total_moe_loss = 0
        for i in range(self.num_layers):
            video, text, moe_loss_layer = self.multiway_list[i](fused, mask_fused, N_video, N_text)
            fused = torch.cat([video, text], dim=1)
            total_moe_loss += moe_loss_layer  # Accumulate MoE loss


        video = self.norm_video(video)
        text = self.norm_text(text)

        cls_video, video = torch.split(video, [1, N_video-1], dim=1)
        cls_text, text = torch.split(text, [1, N_text-1], dim=1)

        pred_video = self.fc_video(video).squeeze(-1) #[B, N]
        pred_text = self.fc_text(text).squeeze(-1) #[B, N]

        # select contrastive pairs for the intra-sample constrastive loss
        key_video_list, nonkey_video_list = self.select_contrastive_embedding(pred_video, video, mask_video[:, 1:], video_label)
        key_text_list, nonkey_text_list = self.select_contrastive_embedding(pred_text, text, mask_text[:, 1:], text_label)

        contrastive_pairs = {
            'key_video_list': key_video_list,
            'nonkey_video_list': nonkey_video_list,
            'key_text_list': key_text_list,
            'nonkey_text_list': nonkey_text_list,
            'cls_video': cls_video,
            'cls_text': cls_text,
        }

        return pred_video, pred_text, contrastive_pairs, total_moe_loss


In [None]:
import logging
import time
import os
import numpy as np

import torch
import torch.nn as nn
import torch.utils.data
from torch.nn.utils.rnn import pad_sequence


from rouge_score import rouge_scorer

# Root logger: the same logger that init_logger() attaches its file handler to.
logger = logging.getLogger()

def train_msmo(args):
    """Train ``Model_MSMO`` on an MSMO dataset, or evaluate saved checkpoints in test mode.

    Test mode (both ``args['checkpoint']`` and ``args['test']`` truthy): load
    ``model_best_text.pt`` and ``model_best_video.pt`` from ``args['checkpoint']``,
    evaluate each on the test split, print the scores and return without training.

    Training mode: run epochs ``start_epoch .. max_epoch - 1``. Every ``eval_freq`` epochs,
    evaluate on the test split. Save ``model_best_text.pt`` on a new best ROUGE-1 and/or
    ``model_best_video.pt`` on a new best video cosine, under ``{model_dir}/checkpoint``.

    Args:
        args: dict of settings. Keys read here:
            - 'device', 'lr', 'weight_decay', 'model_dir', 'batch_size', 'num_workers'
            - 'start_epoch', 'max_epoch', 'print_freq', 'eval_freq', 'dataset'
            - 'lambda_contrastive_inter', 'lambda_contrastive_intra'
            - optional 'lambda_moe' (default 0.1)
            - optional 'checkpoint' / 'test'
            - 'writer' (object with ``add_scalar``; used in training mode only)
            The dict is also forwarded to the model, the datasets and ``evaluate_msmo``.

    Returns:
        Test mode: a 9-tuple ``(val_R1, val_R2, val_RL, val_cos, best_val_epoch,
        max_train_R1, max_train_R2, max_train_RL, max_train_cos)``. The last five
        entries are 0.

        Training mode: a 19-tuple. The first nine entries are:
            - the best validation R1/R2/RL/cos,
            - the epoch of the best validation R1,
            - the best training R1/R2/RL/cos.
        These are followed by the per-epoch training-loss lists
        ``tl_total, tl_text, tl_vid, tl_inter_c, tl_itra_c``, then by the per-batch
        validation-loss lists of the most recent evaluation
        ``vl_inter_c, vl_itra_c, vl_text, vl_vid, vl_total`` (empty if none ran).
    """
    device = args.get('device')
    batch_time = AverageMeter('time')
    data_time = AverageMeter('time')

    model = Model_MSMO(args=args).to(device)
    calc_contrastive_loss = Dual_Contrastive_Loss().to(device)

    # Per-epoch training-loss curves, returned for plotting.
    tl_total, tl_text, tl_vid, tl_inter_c, tl_itra_c = [], [], [], [], []
    # Per-batch validation-loss curves of the latest evaluation. Initialise them here so
    # the final return works even when no epoch reaches an evaluation.
    vl_inter_c, vl_itra_c, vl_text, vl_vid, vl_total = [], [], [], [], []

    # Trainable parameters of the contrastive-loss module (if any) are optimised jointly.
    parameters = [p for p in model.parameters() if p.requires_grad] + \
                 [p for p in calc_contrastive_loss.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(parameters, lr=args.get('lr'), weight_decay=args.get('weight_decay'))

    os.makedirs('{}/checkpoint'.format(args.get('model_dir')), exist_ok=True)

    rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], use_stemmer=True, split_summaries=True)

    max_train_R1 = max_train_R2 = max_train_RL = max_train_cos = 0
    max_val_R1 = max_val_R2 = max_val_RL = max_val_cos = 0
    best_val_epoch = 0

    train_set = MSMODataset(mode='train', args=args)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.get('batch_size'), shuffle=True, num_workers=args.get('num_workers'),
                                               drop_last=False, pin_memory=True,
                                               worker_init_fn=worker_init_fn, collate_fn=my_collate_fn)
    val_set = MSMODataset(mode='test', args=args)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.get('batch_size'), shuffle=False, num_workers=args.get('num_workers'),
                                             drop_last=False, pin_memory=True,
                                             worker_init_fn=worker_init_fn, collate_fn=my_collate_fn)

    if args.get('checkpoint') and args.get('test'):
        # weights_only=False: these checkpoints are written by this function and also hold
        # non-tensor metadata (epoch, best scores). This unpickles arbitrary objects, so only
        # load trusted files.
        checkpoint_path = '{}/model_best_text.pt'.format(args.get('checkpoint'))
        print(f"load checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        # evaluate_msmo returns the four scores followed by loss lists; keep only the scores.
        val_R1, val_R2, val_RL, _ = evaluate_msmo(model, val_loader, args, epoch=0, rouge=rouge)[:4]

        checkpoint_path = '{}/model_best_video.pt'.format(args.get('checkpoint'))
        print(f"load checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        val_cos = evaluate_msmo(model, val_loader, args, epoch=0, rouge=rouge)[3]

        print(f'R1: {val_R1:.4f} R2: {val_R2:.4f} RL: {val_RL:.4f} Cos: {val_cos:.4f}')
        return val_R1, val_R2, val_RL, val_cos, best_val_epoch, max_train_R1, max_train_R2, max_train_RL, max_train_cos

    print('\n' + str(model))

    for epoch in range(args.get('start_epoch'), args.get('max_epoch')):
        model.train()
        stats = AverageMeter('total_loss', 'text_loss', 'video_loss', 'inter_contrastive_loss', 'intra_contrastive_loss', 'R1', 'R2', 'RL', 'cos')

        data_length = len(train_loader)
        end = time.time()
        for k, (video_list, video_summ_list, text_list,
                mask_video_list, mask_video_summ_list, mask_text_list,
                video_label_list, text_label_list, article_segment_list, highlight_list,
                video_to_text_mask_list, text_to_video_mask_list) in enumerate(train_loader):
            data_time.update(time=time.time() - end)

            batch_size = len(video_list)

            # Pad variable-length sequences to the longest one in the batch, then move to device.
            video = pad_sequence(video_list, batch_first=True).to(device)
            video_summ = pad_sequence(video_summ_list, batch_first=True).to(device)
            text = pad_sequence(text_list, batch_first=True).to(device)

            mask_video = pad_sequence(mask_video_list, batch_first=True).to(device)
            mask_video_summ = pad_sequence(mask_video_summ_list, batch_first=True).to(device)
            mask_text = pad_sequence(mask_text_list, batch_first=True).to(device)

            video_label = pad_sequence(video_label_list, batch_first=True).to(device)  # [B, T]
            text_label = pad_sequence(text_label_list, batch_first=True).to(device)  # [B, T]

            for i in range(len(video_to_text_mask_list)):
                video_to_text_mask_list[i] = video_to_text_mask_list[i].to(device)
                text_to_video_mask_list[i] = text_to_video_mask_list[i].to(device)

            pred_video, pred_text, contrastive_pairs, moe_loss = model(
                video=video, text=text,
                mask_video=mask_video, mask_text=mask_text,
                video_label=video_label, text_label=text_label,
                video_to_text_mask_list=video_to_text_mask_list,
                text_to_video_mask_list=text_to_video_mask_list)

            # The predicted summary has as many items as the ground-truth summary of each sample.
            num_frame_selected = torch.sum(video_label, dim=-1)
            num_sentence_selected = torch.sum(text_label, dim=-1)

            mask_video_bool = mask_video.to(torch.bool)
            mask_video_summ_bool = mask_video_summ.to(torch.bool)
            mask_text_bool = mask_text.to(torch.bool)

            # The top-k scored valid frames / sentences form the predicted video / text summary.
            keyframe_index_list = []
            keysentence_index_list = []
            for i in range(batch_size):
                keyframe_index_list.append(torch.topk(pred_video[i, mask_video_bool[i]], k=num_frame_selected[i])[1].tolist())
                keysentence_index_list.append(torch.topk(pred_text[i, mask_text_bool[i]], k=num_sentence_selected[i])[1].tolist())

            text_loss = calc_cls_loss(pred_text, text_label, mask=mask_text)
            if args.get('dataset') in ['Daily_Mail', 'BLiSS']:
                video_loss = calc_cls_loss(pred_video, video_label, mask=mask_video)
            else:
                # Video loss is not used for other datasets; a zero placeholder keeps logging uniform.
                video_loss = torch.zeros(1).to(text_loss)

            inter_contrastive_loss, intra_contrastive_loss = calc_contrastive_loss(contrastive_pairs)
            # Apply each lambda weight exactly once. These weighted values are both optimised
            # and logged.
            inter_contrastive_loss = inter_contrastive_loss * args.get('lambda_contrastive_inter')
            intra_contrastive_loss = intra_contrastive_loss * args.get('lambda_contrastive_intra')
            total_loss = (video_loss + text_loss + inter_contrastive_loss + intra_contrastive_loss
                          + moe_loss * args.get('lambda_moe', 0.1))

            if args.get('dataset') in ['Daily_Mail', 'BLiSS']:
                video_cos = calc_video_cos(video, video_summ, keyframe_index_list, mask_video_summ=mask_video_summ_bool, dataset=args.get('dataset'))
            else:
                video_cos = 0
            text_R1, text_R2, text_RL = calc_text_rouge(article_segment_list, highlight_list, keysentence_index_list, dataset=args.get('dataset'), rouge=rouge)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            stats.update(total_loss=total_loss.item(), text_loss=text_loss.item(), video_loss=video_loss.item(),
                         inter_contrastive_loss=inter_contrastive_loss.item(), intra_contrastive_loss=intra_contrastive_loss.item(),
                         R1=text_R1, R2=text_R2, RL=text_RL, cos=video_cos)

            batch_time.update(time=time.time() - end)
            end = time.time()

            if (k + 1) % args.get('print_freq') == 0:
                print(f'[Train] Epoch: {epoch+1}/{args["max_epoch"]} Iter: {k+1}/{data_length} '
                      f'LR: {optimizer.param_groups[0]["lr"]:.2e} '
                      f'Time: {batch_time.time:.3f} Data: {data_time.time:.3f} '
                      f'Loss: {stats.text_loss:.4f}/{stats.video_loss:.4f}/{stats.inter_contrastive_loss:.4f}/{stats.intra_contrastive_loss:.4f}/{stats.total_loss:.4f} '
                      f'R1: {stats.R1:.4f} R2: {stats.R2:.4f} RL: {stats.RL:.4f} Cos: {stats.cos:.4f}')

        max_train_R1 = max(stats.R1, max_train_R1)
        max_train_R2 = max(stats.R2, max_train_R2)
        max_train_RL = max(stats.RL, max_train_RL)
        max_train_cos = max(stats.cos, max_train_cos)
        tl_total.append(np.mean(stats.total_loss))
        tl_text.append(np.mean(stats.text_loss))
        tl_vid.append(np.mean(stats.video_loss))
        tl_inter_c.append(np.mean(stats.inter_contrastive_loss))
        tl_itra_c.append(np.mean(stats.intra_contrastive_loss))

        print(f'[Train] Epoch: {epoch+1}/{args["max_epoch"]} '
              f'R1: {stats.R1:.4f}/{max_train_R1:.4f} '
              f'R2: {stats.R2:.4f}/{max_train_R2:.4f} '
              f'RL: {stats.RL:.4f}/{max_train_RL:.4f} '
              f'Cos: {stats.cos:.4f}/{max_train_cos:.4f}\n')

        writer = args.get('writer')
        writer.add_scalar(f'Train/max_train_R1', max_train_R1, epoch+1)
        writer.add_scalar(f'Train/max_train_R2', max_train_R2, epoch+1)
        writer.add_scalar(f'Train/max_train_RL', max_train_RL, epoch+1)
        writer.add_scalar(f'Train/max_train_cos', max_train_cos, epoch+1)
        writer.add_scalar(f'Train/train_R1', stats.R1, epoch+1)
        writer.add_scalar(f'Train/train_R2', stats.R2, epoch+1)
        writer.add_scalar(f'Train/train_RL', stats.RL, epoch+1)
        writer.add_scalar(f'Train/train_cos', stats.cos, epoch+1)

        if (epoch + 1) % args.get('eval_freq') == 0:
            (val_R1, val_R2, val_RL, val_cos,
             vl_inter_c, vl_itra_c, vl_text, vl_vid, vl_total) = evaluate_msmo(model, val_loader, args, epoch=epoch, rouge=rouge)

            text_improved = val_R1 > max_val_R1
            video_improved = val_cos > max_val_cos
            max_val_R1 = max(val_R1, max_val_R1)
            max_val_R2 = max(val_R2, max_val_R2)
            max_val_RL = max(val_RL, max_val_RL)
            max_val_cos = max(val_cos, max_val_cos)
            if text_improved:
                best_val_epoch = epoch + 1

            if text_improved or video_improved:
                # Built after the maxima are updated so the stored metadata is current.
                save_checkpoint = {
                    'epoch': epoch+1,
                    'model_state_dict': model.state_dict(),
                    'max_val_R1': max_val_R1,
                    'max_val_R2': max_val_R2,
                    'max_val_RL': max_val_RL,
                    'max_val_cos': max_val_cos,
                }
                if text_improved:
                    torch.save(save_checkpoint, '{}/checkpoint/model_best_text.pt'.format(args.get('model_dir')))
                if video_improved:
                    torch.save(save_checkpoint, '{}/checkpoint/model_best_video.pt'.format(args.get('model_dir')))

            print(f'[Eval]  Epoch: {epoch+1}/{args["max_epoch"]} '
                  f'R1: {val_R1:.4f}/{max_val_R1:.4f} '
                  f'R2: {val_R2:.4f}/{max_val_R2:.4f} '
                  f'RL: {val_RL:.4f}/{max_val_RL:.4f} '
                  f'Cos: {val_cos:.4f}/{max_val_cos:.4f}\n\n')

            writer.add_scalar(f'Val/max_val_R1', max_val_R1, epoch+1)
            writer.add_scalar(f'Val/max_val_R2', max_val_R2, epoch+1)
            writer.add_scalar(f'Val/max_val_RL', max_val_RL, epoch+1)
            writer.add_scalar(f'Val/max_val_cos', max_val_cos, epoch+1)
            writer.add_scalar(f'Val/val_R1', val_R1, epoch+1)
            writer.add_scalar(f'Val/val_R2', val_R2, epoch+1)
            writer.add_scalar(f'Val/val_RL', val_RL, epoch+1)
            writer.add_scalar(f'Val/val_cos', val_cos, epoch+1)

        writer.add_scalar(f'Train/loss', stats.total_loss, epoch+1)
        writer.add_scalar(f'Train/text_loss', stats.text_loss, epoch+1)
        writer.add_scalar(f'Train/video_loss', stats.video_loss, epoch+1)

    return max_val_R1, max_val_R2, max_val_RL, max_val_cos, best_val_epoch, \
        max_train_R1, max_train_R2, max_train_RL, max_train_cos, \
        tl_total, tl_text, tl_vid, tl_inter_c, tl_itra_c, \
        vl_inter_c, vl_itra_c, vl_text, vl_vid, vl_total


@torch.no_grad()
def evaluate_msmo(model, val_loader, args, epoch=None, mode='train', rouge=None, calc_contrastive_loss=None):
    """Evaluate ``model`` on ``val_loader``, returning summary scores and per-batch losses.

    For each batch, the top-k valid frames and sentences are taken as the predicted
    summaries, with k equal to the ground-truth summary size. They are scored with ROUGE
    (text) and, for 'Daily_Mail'/'BLiSS', video cosine similarity.

    Losses use the same weighting as ``train_msmo``:
        - lambda-weighted contrastive terms,
        - a video loss only for 'Daily_Mail'/'BLiSS' (zero otherwise),
        - ``lambda_moe`` (default 0.1) times the model's MoE loss, in the total.

    Args:
        model: model to evaluate. It is put in eval mode and not switched back.
        val_loader: DataLoader whose batches unpack into the 12 lists used below.
        args: settings dict. Reads 'device', 'dataset', 'print_freq', 'max_epoch',
            'lambda_contrastive_inter', 'lambda_contrastive_intra' and optional 'lambda_moe'.
        epoch: 0-based epoch index, used only in progress messages. May be None.
        mode: unused; kept for backward compatibility.
        rouge: ROUGE scorer forwarded to ``calc_text_rouge``.
        calc_contrastive_loss: optional contrastive-loss module to reuse, e.g. the instance
            optimised during training. When None, a fresh ``Dual_Contrastive_Loss`` is
            built, as before.

    Returns:
        ``(R1, R2, RL, cos, vl_inter_c, vl_itra_c, vl_text, vl_vid, vl_total)``. The first
        four are the AverageMeter values over the loader. The rest are per-batch lists of
        Python floats.
    """
    device = args.get('device')
    has_video_summary = args.get('dataset') in ['Daily_Mail', 'BLiSS']
    stats = AverageMeter('R1', 'R2', 'RL', 'cos')
    data_length = len(val_loader)
    if calc_contrastive_loss is None:
        calc_contrastive_loss = Dual_Contrastive_Loss().to(device)
    # Displayed epoch number; '-' when the caller gives no epoch.
    epoch_label = '-' if epoch is None else epoch + 1
    model.eval()
    vl_total = []
    vl_text = []
    vl_vid = []
    vl_inter_c = []
    vl_itra_c = []
    for k, (video_list, video_summ_list, text_list,
            mask_video_list, mask_video_summ_list, mask_text_list,
            video_label_list, text_label_list, article_segment_list, highlight_list,
            video_to_text_mask_list, text_to_video_mask_list) in enumerate(val_loader):

        batch_size = len(video_list)

        # Pad variable-length sequences to the longest one in the batch, then move to device.
        video = pad_sequence(video_list, batch_first=True).to(device)
        video_summ = pad_sequence(video_summ_list, batch_first=True).to(device)
        text = pad_sequence(text_list, batch_first=True).to(device)

        mask_video = pad_sequence(mask_video_list, batch_first=True).to(device)
        mask_video_summ = pad_sequence(mask_video_summ_list, batch_first=True).to(device)
        mask_text = pad_sequence(mask_text_list, batch_first=True).to(device)

        video_label = pad_sequence(video_label_list, batch_first=True).to(device)  # [B, T]
        text_label = pad_sequence(text_label_list, batch_first=True).to(device)  # [B, T]

        for i in range(len(video_to_text_mask_list)):
            video_to_text_mask_list[i] = video_to_text_mask_list[i].to(device)
            text_to_video_mask_list[i] = text_to_video_mask_list[i].to(device)

        pred_video, pred_text, contrastive_pairs, moe_loss = model(
            video=video, text=text,
            mask_video=mask_video, mask_text=mask_text,
            video_label=video_label, text_label=text_label,
            video_to_text_mask_list=video_to_text_mask_list,
            text_to_video_mask_list=text_to_video_mask_list)

        num_frame_selected = torch.sum(video_label, dim=-1)
        num_sentence_selected = torch.sum(text_label, dim=-1)
        mask_video_bool = mask_video.to(torch.bool)
        mask_video_summ_bool = mask_video_summ.to(torch.bool)
        mask_text_bool = mask_text.to(torch.bool)
        keyframe_index_list = []
        keysentence_index_list = []
        for i in range(batch_size):
            keyframe_index_list.append(torch.topk(pred_video[i, mask_video_bool[i]], k=num_frame_selected[i])[1].tolist())
            keysentence_index_list.append(torch.topk(pred_text[i, mask_text_bool[i]], k=num_sentence_selected[i])[1].tolist())

        if has_video_summary:
            video_cos = calc_video_cos(video, video_summ, keyframe_index_list, mask_video_summ=mask_video_summ_bool, dataset=args.get('dataset'))
        else:
            video_cos = 0
        text_R1, text_R2, text_RL = calc_text_rouge(article_segment_list, highlight_list, keysentence_index_list, dataset=args.get('dataset'), rouge=rouge)
        stats.update(R1=text_R1, R2=text_R2, RL=text_RL, cos=video_cos)

        # Losses, weighted exactly as in training so the train and val curves are comparable.
        text_loss = calc_cls_loss(pred_text, text_label, mask=mask_text)
        if has_video_summary:
            video_loss = calc_cls_loss(pred_video, video_label, mask=mask_video)
        else:
            video_loss = torch.zeros(1).to(text_loss)
        inter_contrastive_loss, intra_contrastive_loss = calc_contrastive_loss(contrastive_pairs)
        inter_contrastive_loss = inter_contrastive_loss * args.get('lambda_contrastive_inter')
        intra_contrastive_loss = intra_contrastive_loss * args.get('lambda_contrastive_intra')
        # float() accepts both a one-element tensor and a plain number.
        moe_term = float(moe_loss) * args.get('lambda_moe', 0.1)

        vl_inter_c.append(inter_contrastive_loss.item())
        vl_itra_c.append(intra_contrastive_loss.item())
        vl_text.append(text_loss.item())
        vl_vid.append(video_loss.item())
        vl_total.append(video_loss.item() + text_loss.item() + inter_contrastive_loss.item()
                        + intra_contrastive_loss.item() + moe_term)
        if (k + 1) % args.get('print_freq') == 0:
            print(f'[Eval]  Epoch: {epoch_label}/{args["max_epoch"]} Iter: {k+1}/{data_length} '
                  f'R1: {stats.R1:.4f} R2: {stats.R2:.4f} RL: {stats.RL:.4f} Cos: {stats.cos:.4f}')
    return stats.R1, stats.R2, stats.RL, stats.cos, vl_inter_c, vl_itra_c, vl_text, vl_vid, vl_total


In [None]:
# Install PyTorch/XLA so that `torch_xla` can be imported and the 'xla' device used by the TPU config.
# NOTE(review): unpinned — the torch_xla wheel must match the installed torch version; consider pinning both.
!pip install torch_xla[tpu] -f https://storage.googleapis.com/tpu-pytorch/wheels/colab.html

Looking in links: https://storage.googleapis.com/tpu-pytorch/wheels/colab.html


In [None]:
import torch_xla
import torch_xla.core.xla_model as xm



In [None]:
# Training run on the Daily_Mail dataset using the XLA (TPU) device and mixture-of-experts layers.
args = {
    # --- data, I/O and runtime ---
    "dataset": 'Daily_Mail',
    "data_root": '/content/drive/MyDrive/data',
    "device": 'xla',
    "start_epoch": 0,
    "num_workers": 2,
    "model_dir": 'logsV4',
    "log_file": 'logV1.txt',
    "nms_thresh": 0.4,
    "print_freq": 5,
    "eval_freq": 1,
    # To evaluate saved checkpoints instead of training, add
    # "checkpoint": '<directory holding model_best_text.pt / model_best_video.pt>'
    # and set "test" to True.
    "test": False,
    "num_feature": 512,
    # --- optimisation and mixture-of-experts ---
    "lr": 2e-4,
    "lambda_moe": 0.4,      # weight of the MoE loss in the training total loss
    "num_experts": 5,       # experts per MoE layer
                            # NOTE(review): the printed model shows 4 experts — confirm the model reads this key
    "moe_dropout": 0.05,
    "weight_decay": 1e-7,
    "max_epoch": 1,
    "batch_size": 4,
    "seed": 12345,
    # --- input / hidden sizes ---
    "num_input_video": 2048,
    "num_input_text": 768,
    "num_hidden": 256,
    "num_layers": 2,
    # --- dropout ---
    "dropout_video": 0.1,
    "dropout_text": 0.1,
    "dropout_attn": 0.1,
    "dropout_fc": 0.5,
    # --- contrastive loss weights and misc ---
    "lambda_contrastive_inter": 0.001,
    "lambda_contrastive_intra": 0.001,
    "ratio": 8,
}

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Create the run directory before anything writes into it (init_logger, args.yml, TensorBoard).
os.makedirs(args.get('model_dir'), exist_ok=True)
init_logger(args["model_dir"], args.get('log_file'))
set_random_seed(args.get('seed'))
dump_yaml(args, '{}/args.yml'.format(args["model_dir"]))

logger.info(args)
print(args.get('model_dir'))

args["writer"] = SummaryWriter(os.path.join(args["model_dir"], 'tensorboard'))

try:
    # train_msmo returns 9 values in test mode and 19 in training mode (the same 9, then
    # the loss curves). Keep the full tuple for plotting; the first 9 entries mean the same
    # thing in both modes.
    train_results = train_msmo(args)
finally:
    # Flush pending TensorBoard events even if the run is interrupted.
    args["writer"].close()

(max_val_R1, max_val_R2, max_val_RL, max_val_cos, best_val_epoch,
 max_train_R1, max_train_R2, max_train_RL, max_train_cos) = train_results[:9]

logger.info(f'Training done. Val R1: {max_val_R1:.4f}, R2: {max_val_R2:.4f}, RL: {max_val_RL:.4f}, Cos: {max_val_cos:.4f}, Best Epoch:{best_val_epoch}.')
logger.info(f'             Train R1: {max_train_R1:.4f}, R2: {max_train_R2:.4f}, RL: {max_train_RL:.4f}, Cos: {max_train_cos:.4f}.\n\n')



logsV4


100%|██████████| 1207/1207 [00:02<00:00, 415.65it/s]
100%|██████████| 352/352 [00:00<00:00, 574.58it/s]



Model_MSMO(
  (proj_fc_video): Sequential(
    (0): Linear(in_features=2048, out_features=256, bias=True)
    (1): Dropout(p=0.1, inplace=False)
  )
  (proj_fc_text): Sequential(
    (0): Linear(in_features=768, out_features=256, bias=True)
    (1): Dropout(p=0.1, inplace=False)
  )
  (multiway_list): ModuleList(
    (0-1): 2 x MultiWayTransformer(
      (norm1_fused): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (attn_fusion): MultiHeadAttention(q_dims=256, k_dims=256, v_dims=256, h_dims=256, o_dims=256, heads=8, p=0.1, bias=True)
      (norm2_video): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (moe_video): MoELayer(
        (experts): ModuleList(
          (0-3): 4 x FFN(
            (fc1): Linear(in_features=256, out_features=1024, bias=True)
            (act): GELU(approximate='none')
            (drop1): Dropout(p=0.1, inplace=False)
            (fc2): Linear(in_features=1024, out_features=256, bias=True)
            (drop2): Dropout(p=0.1, inplace=

KeyboardInterrupt: 

In [None]:
# Test-mode configuration on CPU. With "checkpoint" set and "test" True, train_msmo loads
# model_best_text.pt / model_best_video.pt from "checkpoint" and only evaluates them.
# NOTE(review): no MoE keys (lambda_moe / num_experts / moe_dropout) here — presumably the
# model falls back to its defaults; confirm it matches the checkpoint's architecture.
args = {
    # --- data, I/O and runtime ---
    "dataset": 'Daily_Mail',
    "data_root": '/content/drive/MyDrive/data',
    "device": 'cpu',
    "start_epoch": 0,
    "num_workers": 2,
    "model_dir": '/content/logsV3',
    "log_file": 'logV3.txt',
    "nms_thresh": 0.4,
    "print_freq": 5,
    "eval_freq": 1,
    "checkpoint": '/content/logsV3/checkpoint',
    "test": True,
    "num_feature": 512,
    # --- optimisation ---
    "lr": 2e-4,
    "weight_decay": 1e-7,
    "max_epoch": 1,
    "batch_size": 4,
    "seed": 12345,
    # --- input / hidden sizes ---
    "num_input_video": 2048,
    "num_input_text": 768,
    "num_hidden": 256,
    "num_layers": 2,
    # --- dropout ---
    "dropout_video": 0.1,
    "dropout_text": 0.1,
    "dropout_attn": 0.1,
    "dropout_fc": 0.5,
    # --- contrastive loss weights and misc ---
    "lambda_contrastive_inter": 0.001,
    "lambda_contrastive_intra": 0.001,
    "ratio": 8,
}

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Same driver as the training run, using the CPU/test configuration currently in `args`.
# Create the run directory before anything writes into it (init_logger, args.yml, TensorBoard).
os.makedirs(args.get('model_dir'), exist_ok=True)
init_logger(args["model_dir"], args.get('log_file'))
set_random_seed(args.get('seed'))
dump_yaml(args, '{}/args.yml'.format(args["model_dir"]))

logger.info(args)
print(args.get('model_dir'))

args["writer"] = SummaryWriter(os.path.join(args["model_dir"], 'tensorboard'))

try:
    # train_msmo returns 9 values in test mode and 19 in training mode (the same 9, then
    # the loss curves). The first 9 entries mean the same thing in both modes.
    train_results = train_msmo(args)
finally:
    # Flush pending TensorBoard events even if the run is interrupted.
    args["writer"].close()

(max_val_R1, max_val_R2, max_val_RL, max_val_cos, best_val_epoch,
 max_train_R1, max_train_R2, max_train_RL, max_train_cos) = train_results[:9]

logger.info(f'Training done. Val R1: {max_val_R1:.4f}, R2: {max_val_R2:.4f}, RL: {max_val_RL:.4f}, Cos: {max_val_cos:.4f}, Best Epoch:{best_val_epoch}.')
logger.info(f'             Train R1: {max_train_R1:.4f}, R2: {max_train_R2:.4f}, RL: {max_train_RL:.4f}, Cos: {max_train_cos:.4f}.\n\n')



/content/logsV3


100%|██████████| 1207/1207 [00:00<00:00, 6058.72it/s]
100%|██████████| 352/352 [00:00<00:00, 6009.18it/s]

load checkpoint from /content/logsV3/checkpoint/model_best_text.pt





FileNotFoundError: [Errno 2] No such file or directory: '/content/logsV3/checkpoint/model_best_text.pt'

In [None]:
# Two independent instances built from the same config: one receives the best-ROUGE
# checkpoint (model_best_text.pt), the other the best-cosine one (model_best_video.pt).
model_txt, model_vid = Model_MSMO(args=args), Model_MSMO(args=args)

In [None]:
def _load_best_checkpoint(model, checkpoint_dir, file_name):
    """Load ``{checkpoint_dir}/{file_name}['model_state_dict']`` into ``model``.

    Returns the ``load_state_dict`` result, which reports missing/unexpected keys.
    ``weights_only=False`` is needed because these checkpoints also store non-tensor
    metadata (epoch, best scores), and torch>=2.6 defaults to ``weights_only=True``.
    It unpickles arbitrary objects, so only use it on checkpoints you produced yourself.
    """
    checkpoint_path = '{}/{}'.format(checkpoint_dir, file_name)
    print(f"load checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    return model.load_state_dict(checkpoint['model_state_dict'])


_load_best_checkpoint(model_txt, args.get('checkpoint'), 'model_best_text.pt')
_load_best_checkpoint(model_vid, args.get('checkpoint'), 'model_best_video.pt')

load checkpoint from /media/pc/New Volume/Saad/Models/A2Summ/logs/checkpoint/model_best_text.pt
load checkpoint from /media/pc/New Volume/Saad/Models/A2Summ/logs/checkpoint/model_best_video.pt


  checkpoint = torch.load(checkpoint_path, map_location='cpu')
  checkpoint = torch.load(checkpoint_path, map_location='cpu')


<All keys matched successfully>

In [None]:
# Print the module tree of the model that received model_best_text.pt.
print(model_txt)

Model_MSMO(
  (proj_fc_video): Sequential(
    (0): Linear(in_features=2048, out_features=256, bias=True)
    (1): Dropout(p=0.1, inplace=False)
  )
  (proj_fc_text): Sequential(
    (0): Linear(in_features=768, out_features=256, bias=True)
    (1): Dropout(p=0.1, inplace=False)
  )
  (multiway_list): ModuleList(
    (0-1): 2 x MultiWayTransformer(
      (norm1_fused): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (attn_fusion): MultiHeadAttention(q_dims=256, k_dims=256, v_dims=256, h_dims=256, o_dims=256, heads=8, p=0.1, bias=True)
      (norm2_video): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (ffn_video): FFN(
        (fc1): Linear(in_features=256, out_features=1024, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.1, inplace=False)
        (fc2): Linear(in_features=1024, out_features=256, bias=True)
        (drop2): Dropout(p=0.1, inplace=False)
      )
      (norm2_text): LayerNorm((256,), eps=1e-05, elementwise_affine=

In [None]:
# Uninstall the current PyTorch installation (if any).
!pip uninstall -y torch torchvision torchaudio

# Install PyTorch with CUDA support.
# Make sure to select the correct CUDA version for your environment.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Restart the runtime. This is important for the changes to take effect.
import os
os.kill(os.getpid(), 9)

Found existing installation: torch 2.5.1+cpu
Uninstalling torch-2.5.1+cpu:
  Successfully uninstalled torch-2.5.1+cpu
Found existing installation: torchvision 0.20.1+cpu
Uninstalling torchvision-0.20.1+cpu:
  Successfully uninstalled torchvision-0.20.1+cpu
Found existing installation: torchaudio 2.5.1+cpu
Uninstalling torchaudio-2.5.1+cpu:
  Successfully uninstalled torchaudio-2.5.1+cpu
Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch
  Downloading https://download.pytorch.org/whl/cu118/torch-2.6.0%2Bcu118-cp311-cp311-linux_x86_64.whl.metadata (27 kB)
Collecting torchvision
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.21.0%2Bcu118-cp311-cp311-linux_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/cu118/torchaudio-2.6.0%2Bcu118-cp311-cp311-linux_x86_64.whl.metadata (6.6 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cu