In [60]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from tqdm import tqdm
import random
import numpy as np
import os
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from scipy.ndimage import label as scipy_label

In [61]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [62]:
import torch
import torch.nn as nn

class CRF(nn.Module):
    """
    Реализация Conditional Random Field (CRF) (были проблемы с импортом)
    """
    def __init__(self, num_tags: int, batch_first: bool = True):
        if num_tags <= 0:
            raise ValueError(f"invalid number of tags: {num_tags}")
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        nn.init.uniform_(self.transitions, -0.1, 0.1)
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)

    def forward(self, emissions, tags=None, mask=None, reduction: str = 'mean'):
        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            if tags is not None:
                tags = tags.transpose(0, 1)
            if mask is not None:
                mask = mask.transpose(0, 1)
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8)
        log_likelihood = self._compute_log_likelihood(emissions, tags, mask)
        if reduction == 'sum':
            return -log_likelihood.sum()
        if reduction == 'mean':
            return -log_likelihood.mean()
        return -log_likelihood

    def decode(self, emissions, mask=None):
        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            if mask is not None:
                mask = mask.transpose(0, 1)
        if mask is None:
            mask = torch.ones(emissions.shape[:2], dtype=torch.uint8, device=emissions.device)
        return self._viterbi_decode(emissions, mask)

    def _compute_log_likelihood(self, emissions, tags, mask):
        seq_length, batch_size, _ = emissions.shape
        log_alpha = self._forward_pass(emissions, mask)
        gold_score = self._score_sequence(emissions, tags, mask)    
        return gold_score - log_alpha

    def _forward_pass(self, emissions, mask):
        seq_length, batch_size, _ = emissions.shape
        log_alpha = self.start_transitions + emissions[0]
        for i in range(1, seq_length):
            emit_scores = emissions[i].unsqueeze(1)
            trans_scores = self.transitions.unsqueeze(0)
            alpha_t = log_alpha.unsqueeze(2)
            scores = trans_scores + alpha_t + emit_scores
            log_alpha_next = torch.logsumexp(scores, dim=1)
            mask_t = mask[i].unsqueeze(1).float()
            log_alpha = mask_t * log_alpha_next + (1 - mask_t) * log_alpha
        log_alpha += self.end_transitions
        return torch.logsumexp(log_alpha, dim=1)


    def _score_sequence(self, emissions, tags, mask):
        batch_size = emissions.size(1)
        score = self.start_transitions[tags[0]]
        score += (emissions.gather(2, tags.unsqueeze(2)).squeeze(2) * mask.float()).sum(0)
        for i in range(emissions.size(0) - 1):
            score += self.transitions[tags[i], tags[i+1]] * mask[i+1].float()
        seq_ends = mask.long().sum(dim=0) - 1
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        score += self.end_transitions[last_tags]
        return score
    
    def _viterbi_decode(self, emissions, mask):
        seq_length, batch_size, _ = emissions.shape
        log_delta = self.start_transitions + emissions[0]
        backpointers = []
        for i in range(1, seq_length):
            delta_t = log_delta.unsqueeze(2)
            trans_scores = self.transitions.unsqueeze(0)
            scores = delta_t + trans_scores
            log_delta_next, backpointers_t = torch.max(scores, dim=1)
            log_delta_next += emissions[i]
            mask_t = mask[i].unsqueeze(1).float()
            log_delta = mask_t * log_delta_next + (1 - mask_t) * log_delta
            backpointers.append(backpointers_t)

        log_delta += self.end_transitions
        best_last_tag = torch.argmax(log_delta, dim=1)
        best_path = [best_last_tag]
        for backpointers_t in reversed(backpointers):
            best_last_tag = backpointers_t.gather(1, best_last_tag.unsqueeze(1)).squeeze(1)
            best_path.insert(0, best_last_tag)    
        return torch.stack(best_path).transpose(0, 1).tolist()


class IntroDetectionTransformer(nn.Module):
    def __init__(self, d_model=768, n_heads=8, n_layers=4, num_labels=2, class_weights=None):
        super().__init__()
        self.pos_encoding = nn.Parameter(torch.zeros(1, 60, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True, dropout=0.1)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.classifier = nn.Linear(d_model, num_labels)
        self.crf = CRF(num_labels, batch_first=True)
    
        if class_weights is not None:
            self.register_buffer('class_weights', class_weights)
        else:
            self.class_weights = None

    def forward(self, embeddings, labels=None, mask=None):
        x = embeddings + self.pos_encoding
        x = self.transformer(x) 
        logits = self.classifier(x)

        if labels is not None:
            if self.class_weights is not None:
                logits[:, :, 1] = logits[:, :, 1] * self.class_weights[1]
            
            return self.crf(logits, labels.long(), mask=mask)
        else:
            return self.crf.decode(logits, mask=mask)

In [63]:
dataset = VideoDataset('/kaggle/input/introdetectiondataset/train_dataset2.pt')
# labels = [1.0 if any(item["labels"] == 1) else 0.0 for item in dataset.data]
# weights = [20.0 if label == 1 else 1.0 for label in labels]  # Oversampling для заставок
# sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16)

In [64]:
print(f"Количество окон: {len(dataset)}")
has_positive = any(any(item[1] == 1) for item in dataset)
print(f"Есть ли кадры заставки (метка 1): {has_positive}")
positive_windows = sum(1 for item in dataset if any(item[1] == 1))
print(f"Количество окон с кадрами заставки: {positive_windows}")
total_positive_frames = sum(item[1].sum().item() for item in dataset)
print(f"Общее количество кадров с меткой 1: {total_positive_frames}")
print(f"Доля кадров с меткой 1: {total_positive_frames / (len(dataset) * 60):.4%}")
if len(dataset) > 0:
    print(f"Размерность эмбеддингов: {dataset[0][0].shape}")

Количество окон: 149
Есть ли кадры заставки (метка 1): True
Количество окон с кадрами заставки: 64
Общее количество кадров с меткой 1: 412.0
Доля кадров с меткой 1: 4.6085%
Размерность эмбеддингов: torch.Size([60, 768])


In [65]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Training on device: {device}")


d_model = 768
class_weights = torch.tensor([1.0, 2.0])  # Интро важнее
model = IntroDetectionTransformer(d_model=d_model, n_heads=12, n_layers=6, num_labels=2, class_weights=class_weights).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5, weight_decay=1e-5)

torch.nn.TransformerEncoder.use_nested_tensor = False

num_epochs = 30
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    for embeddings, labels in tqdm(dataloader, desc=f"Эпоха {epoch+1}/{num_epochs}"):
        embeddings, labels = embeddings.to(device), labels.to(device)
        mask = torch.ones(embeddings.shape[0], 60, device=device).bool()
        optimizer.zero_grad()
        loss = model(embeddings, labels, mask)
        if loss.dim() > 0:
            loss = loss.mean()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Эпоха {epoch+1}, Средняя потеря: {total_loss / len(dataloader):.4f}")

model_path = os.path.join('/kaggle/working/', "model.pt")
torch.save(model.state_dict(), model_path)
print(f"Модель сохранена в {model_path}")

Training on device: cuda


Эпоха 1/30: 100%|██████████| 10/10 [00:01<00:00,  9.02it/s]


Эпоха 1, Средняя потеря: 24.1686


Эпоха 2/30: 100%|██████████| 10/10 [00:01<00:00,  9.98it/s]


Эпоха 2, Средняя потеря: 12.5975


Эпоха 3/30: 100%|██████████| 10/10 [00:00<00:00, 10.08it/s]


Эпоха 3, Средняя потеря: 10.2740


Эпоха 4/30: 100%|██████████| 10/10 [00:00<00:00, 10.14it/s]


Эпоха 4, Средняя потеря: 9.4076


Эпоха 5/30: 100%|██████████| 10/10 [00:01<00:00,  9.82it/s]


Эпоха 5, Средняя потеря: 8.7030


Эпоха 6/30: 100%|██████████| 10/10 [00:01<00:00,  9.96it/s]


Эпоха 6, Средняя потеря: 8.0098


Эпоха 7/30: 100%|██████████| 10/10 [00:00<00:00, 10.00it/s]


Эпоха 7, Средняя потеря: 7.2284


Эпоха 8/30: 100%|██████████| 10/10 [00:01<00:00,  9.84it/s]


Эпоха 8, Средняя потеря: 6.7306


Эпоха 9/30: 100%|██████████| 10/10 [00:01<00:00,  9.91it/s]


Эпоха 9, Средняя потеря: 6.2217


Эпоха 10/30: 100%|██████████| 10/10 [00:01<00:00,  9.88it/s]


Эпоха 10, Средняя потеря: 5.6213


Эпоха 11/30: 100%|██████████| 10/10 [00:01<00:00,  9.95it/s]


Эпоха 11, Средняя потеря: 5.1102


Эпоха 12/30: 100%|██████████| 10/10 [00:01<00:00,  9.85it/s]


Эпоха 12, Средняя потеря: 5.4195


Эпоха 13/30: 100%|██████████| 10/10 [00:01<00:00,  9.73it/s]


Эпоха 13, Средняя потеря: 5.8463


Эпоха 14/30: 100%|██████████| 10/10 [00:01<00:00,  9.77it/s]


Эпоха 14, Средняя потеря: 7.2038


Эпоха 15/30: 100%|██████████| 10/10 [00:01<00:00,  9.51it/s]


Эпоха 15, Средняя потеря: 7.6556


Эпоха 16/30: 100%|██████████| 10/10 [00:01<00:00,  9.47it/s]


Эпоха 16, Средняя потеря: 4.9552


Эпоха 17/30: 100%|██████████| 10/10 [00:01<00:00,  9.72it/s]


Эпоха 17, Средняя потеря: 4.7068


Эпоха 18/30: 100%|██████████| 10/10 [00:01<00:00,  9.69it/s]


Эпоха 18, Средняя потеря: 3.6334


Эпоха 19/30: 100%|██████████| 10/10 [00:01<00:00,  9.61it/s]


Эпоха 19, Средняя потеря: 3.3308


Эпоха 20/30: 100%|██████████| 10/10 [00:01<00:00,  9.61it/s]


Эпоха 20, Средняя потеря: 3.0314


Эпоха 21/30: 100%|██████████| 10/10 [00:01<00:00,  9.42it/s]


Эпоха 21, Средняя потеря: 2.7904


Эпоха 22/30: 100%|██████████| 10/10 [00:01<00:00,  9.42it/s]


Эпоха 22, Средняя потеря: 2.9009


Эпоха 23/30: 100%|██████████| 10/10 [00:01<00:00,  9.41it/s]


Эпоха 23, Средняя потеря: 2.4651


Эпоха 24/30: 100%|██████████| 10/10 [00:01<00:00,  9.25it/s]


Эпоха 24, Средняя потеря: 2.3119


Эпоха 25/30: 100%|██████████| 10/10 [00:01<00:00,  9.41it/s]


Эпоха 25, Средняя потеря: 2.3643


Эпоха 26/30: 100%|██████████| 10/10 [00:01<00:00,  9.46it/s]


Эпоха 26, Средняя потеря: 2.5779


Эпоха 27/30: 100%|██████████| 10/10 [00:01<00:00,  9.38it/s]


Эпоха 27, Средняя потеря: 2.5489


Эпоха 28/30: 100%|██████████| 10/10 [00:01<00:00,  9.40it/s]


Эпоха 28, Средняя потеря: 1.8679


Эпоха 29/30: 100%|██████████| 10/10 [00:01<00:00,  9.40it/s]


Эпоха 29, Средняя потеря: 1.5295


Эпоха 30/30: 100%|██████████| 10/10 [00:01<00:00,  9.45it/s]


Эпоха 30, Средняя потеря: 2.4235
Модель сохранена в /kaggle/working/model.pt


In [66]:
class VideoDataset2(Dataset):
    def __init__(self, dataset_path):
        self.data = torch.load(dataset_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return item["embeddings"], item["labels"], item["video_id"], item["window_start_frame_idx"]

In [72]:
from scipy.ndimage import label as scipy_label
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
from scipy.ndimage import binary_closing


def calculate_iou(pred_start, pred_end, true_start, true_end):
    intersection_start = max(pred_start, true_start)
    intersection_end = min(pred_end, true_end)
    intersection = max(0, intersection_end - intersection_start)
    union = (pred_end - pred_start) + (true_end - true_start) - intersection
    return intersection / union if union > 0 else 0.0


test_dataset_path = '/kaggle/input/introdetectiondataset/test_dataset2.pt'
test_dataset = VideoDataset2(test_dataset_path)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

d_model = 768
#+ Передаем num_labels=2
model = IntroDetectionTransformer(d_model=d_model, n_heads=12, n_layers=6, num_labels=2, class_weights=class_weights).to(device)
model_path = '/kaggle/working/model.pt'
model.load_state_dict(torch.load(model_path))
model.eval()

#+ all_probs больше не существует, будут all_preds
all_preds_raw = [] # Для исходных предсказаний
all_preds_closed = [] # Для обработанных предсказаний
all_labels = []
video_predictions = {}
video_ground_truth = {}

with torch.no_grad():
    print("Исходные переходы CRF:\n", model.crf.transitions.data)
    
    boost_self_transition = 1.5  # Увеличиваем вероятность остаться в состоянии "заставка"
    penalize_exit_transition = -1.5 # Уменьшаем вероятность выйти из состояния "заставка"

    # Сохраняем оригинальные веса
    original_transitions = model.crf.transitions.data.clone()
    
    model.crf.transitions.data[1, 1] += boost_self_transition
    model.crf.transitions.data[1, 0] += penalize_exit_transition
    
    print("\nИзмененные переходы CRF:\n", model.crf.transitions.data)
    for embeddings, labels, video_ids, start_indices in tqdm(test_dataloader, desc="Оценка"):
        if embeddings.shape[0] <= 1:
            print(f"Пропущен батч размером {embeddings.shape[0]} для стабильности.")
            continue
        embeddings = embeddings.to(device)
        mask = torch.ones(embeddings.shape[0], 60, device=device, dtype=torch.bool)

        predicted_sequences = model(embeddings, mask=mask)
        
        predicted_labels = np.array(predicted_sequences)
        true_labels = labels.cpu().numpy()
        closed_labels_batch = []
        for pred_seq in predicted_labels:
            closed_labels_batch.append(binary_closing(pred_seq, structure=np.ones(3)).astype(int))
        closed_labels_batch = np.array(closed_labels_batch)
        
        all_preds_raw.append(predicted_labels)
        all_preds_closed.append(closed_labels_batch)
        all_labels.append(true_labels)
        
        for i, (video_id, start_idx, pred_seq_closed, true_seq) in enumerate(zip(video_ids, start_indices, closed_labels_batch, true_labels)):
            start_idx = start_idx.item()
            if video_id not in video_predictions:
                video_predictions[video_id] = []
                video_ground_truth[video_id] = []

            labeled_preds, num_groups = scipy_label(pred_seq_closed)
            for group in range(1, num_groups + 1):
                indices = np.where(labeled_preds == group)[0]
                if len(indices) >= 3:
                    start_sec = start_idx + indices[0]
                    end_sec = start_idx + indices[-1] + 1
                    video_predictions[video_id].append((start_sec, end_sec))

            labeled_trues, num_true_groups = scipy_label(true_seq)
            for group in range(1, num_true_groups + 1):
                indices = np.where(labeled_trues == group)[0]
                if indices.size > 0:
                    start_sec = start_idx + indices[0]
                    end_sec = start_idx + indices[-1] + 1
                    if (start_sec, end_sec) not in video_ground_truth.get(video_id, []):
                         video_ground_truth[video_id].append((start_sec, end_sec))
    model.crf.transitions.data = original_transitions


all_preds_raw_flat = np.concatenate(all_preds_raw).flatten()
all_labels_flat = np.concatenate(all_labels).flatten()

print("\n--- Метрики ДО постобработки ---")
print(f"Precision: {precision_score(all_labels_flat, all_preds_raw_flat, zero_division=0):.4f}")
print(f"Recall:    {recall_score(all_labels_flat, all_preds_raw_flat, zero_division=0):.4f}")
print(f"F1-мера:   {f1_score(all_labels_flat, all_preds_raw_flat, zero_division=0):.4f}")


all_preds_closed_flat = np.concatenate(all_preds_closed).flatten()

print("\n--- Метрики ПОСЛЕ постобработки (binary_closing) ---")
print(f"Precision: {precision_score(all_labels_flat, all_preds_closed_flat, zero_division=0):.4f}")
print(f"Recall:    {recall_score(all_labels_flat, all_preds_closed_flat, zero_division=0):.4f}")
print(f"F1-мера:   {f1_score(all_labels_flat, all_preds_closed_flat, zero_division=0):.4f}")
 
iou_scores = []
for video_id in video_predictions:
    pred_intervals = video_predictions.get(video_id, [])
    true_intervals = video_ground_truth.get(video_id, [])
    if not true_intervals:
        continue
    for pred in pred_intervals:
        for true in true_intervals:
            iou = calculate_iou(pred[0], pred[1], true[0], true[1])
            iou_scores.append(iou)

mean_iou = np.mean(iou_scores) if iou_scores else 0.0
print("\nМетрики на уровне интервалов:")
print(f"Средний IoU: {mean_iou:.4f}")

Исходные переходы CRF:
 tensor([[ 0.0791,  0.0579],
        [-0.0193,  0.0196]], device='cuda:0')

Измененные переходы CRF:
 tensor([[ 0.0791,  0.0579],
        [-1.5193,  1.5196]], device='cuda:0')


Оценка: 100%|██████████| 3/3 [00:00<00:00, 31.78it/s]


--- Метрики ДО постобработки ---
Precision: 0.6536
Recall:    0.7752
F1-мера:   0.7092

--- Метрики ПОСЛЕ постобработки (binary_closing) ---
Precision: 0.6586
Recall:    0.7726
F1-мера:   0.7111

Метрики на уровне интервалов:
Средний IoU: 0.6295





In [73]:
%cd /kaggle/working
from IPython.display import FileLink
FileLink('model.pt')

/kaggle/working
