In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from src.model_def1 import *
from src.gest_segm import *
from src.dataset import *

# Directory containing the RTTM files
directory = 'annot_prepare/results/rttm'
segments = get_segments_from_directory(directory)

# Generate dummy data
batch_size = 16
seq_len = 2500
input_dim = 257 * 3  # Number of channels (LPS, IPD, AF) concatenated
num_speakers = 4  # Fixed to 4
num_channels = 512
bottleneck_size = 256
kernel_size = 3
frame_size = 0.02  # Placeholder for actual frame size in seconds
num_frames = seq_len  # Assuming 1 frame per time step
n_seconds = int(num_frames * frame_size)

# Filter segments before n_seconds
def filter_segments_before_n_seconds(segments, n_seconds):
    return [seg for seg in segments if seg[2] + seg[3] <= n_seconds]

filtered_segments = filter_segments_before_n_seconds(segments, n_seconds)

# Generate dummy features
lps = torch.randn(batch_size, 257, seq_len)
ipd = torch.randn(batch_size, 257, seq_len)
af = torch.randn(batch_size, 257, seq_len)

# Concatenate features
features = torch.cat((lps, ipd, af), dim=1)

# Generate VAD and OSD labels
labels_vad = segments_to_vad_labels(segments, num_speakers, num_frames, frame_size)
labels_osd = segments_to_osd_labels(segments, num_frames, frame_size)

# Adjust labels to match the batch size
labels_vad = torch.tensor(labels_vad)  # Shape: [num_speakers, num_frames]
labels_osd = torch.tensor(labels_osd)  # Shape: [num_frames]

# Expand labels to match the batch size
labels_vad = labels_vad.unsqueeze(1).expand(-1, batch_size, -1)  # Shape: [num_speakers, batch_size, num_frames]
labels_osd = labels_osd.unsqueeze(0).expand(batch_size, -1)      # Shape: [batch_size, num_frames]



# Create dataset and dataloader
dataset = DiarizationDataset(features, labels_vad, labels_osd)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define ConvTasNet model
class ConvTasNet(nn.Module):
    def __init__(self, input_dim, bottleneck_size=256, num_channels=512, kernel_size=3, num_speakers=4):
        super(ConvTasNet, self).__init__()
        self.tcn = TemporalConvNet(input_dim, [num_channels]*8, kernel_size)
        self.bottleneck = nn.Conv1d(num_channels, bottleneck_size, 1)
        self.decoders_vad = nn.ConvTranspose1d(bottleneck_size, num_speakers, kernel_size=2, stride=2)
        self.decoders_osd = nn.ConvTranspose1d(bottleneck_size, 1, kernel_size=2, stride=2)
    
    def forward(self, lps, ipd, af):
        # Concatenate inputs along the channel dimension
        x = torch.cat((lps, ipd, af), dim=1)
        
        # TCN blocks
        x = self.tcn(x)
        
        # Bottleneck layer
        x = self.bottleneck(x)
        
        # Decoder layers
        y_vad = self.decoders_vad(x)
        y_osd = self.decoders_osd(x)
        
        # Reshape outputs
        y_vad = y_vad.permute(0, 2, 1)  # Shape: [batch_size, num_frames, num_speakers]
        y_osd = y_osd.squeeze(1)  # Shape: [batch_size, num_frames]
        
        return y_vad, y_osd

# Initialize model
conv_tasnet = ConvTasNet(input_dim=input_dim, num_speakers=num_speakers)

# Define loss functions and optimizer
criterion_vad = nn.BCELoss()
criterion_osd = nn.BCELoss()
optimizer = optim.Adam(conv_tasnet.parameters(), lr=1e-4)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    conv_tasnet.train()
    running_loss_vad = 0.0
    running_loss_osd = 0.0
    
    for i, (inputs, vad_labels, osd_labels) in enumerate(dataloader):
        optimizer.zero_grad()
        
        # Forward pass
        vad_outputs, osd_outputs = conv_tasnet(inputs[:, :257, :], inputs[:, 257:514, :], inputs[:, 514:, :])
        
        # Compute losses
        vad_labels = vad_labels.permute(1, 2, 0)  # Shape: [batch_size, num_speakers, num_frames]
        loss_vad = criterion_vad(vad_outputs, vad_labels.float())
        loss_osd = criterion_osd(osd_outputs, osd_labels.float())
        loss = loss_vad + loss_osd
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        running_loss_vad += loss_vad.item()
        running_loss_osd += loss_osd.item()
        
        if i % 10 == 9:  # Print every 10 batches
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(dataloader)}], "
                  f"Loss VAD: {running_loss_vad/10:.4f}, Loss OSD: {running_loss_osd/10:.4f}")
            running_loss_vad = 0.0
            running_loss_osd = 0.0

print('Finished Training')

ValueError: Using a target size (torch.Size([4, 2500, 16])) that is different to the input size (torch.Size([16, 5000, 4])) is deprecated. Please ensure they have the same size.