In [11]:
from datasets import load_from_disk
from dataclasses import dataclass, field
import torch

dataset = load_from_disk("/teamspace/studios/this_studio/FinalProject/wav2vec_finetune/wav2vec2bert/datasets/fleurs-filtered-proc")

In [14]:
from transformers import Wav2Vec2BertProcessor
processor = Wav2Vec2BertProcessor.from_pretrained("models/facebook/w2v-bert-2.0-finetuned")


In [15]:
from accelerate import Accelerator, DistributedType
accelerator = Accelerator(mixed_precision="no")


In [17]:
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    def __init__(self, processor, accelerator, input_key='input_features', padding=True):
        self.processor = processor
        self.accelerator= accelerator
        self.input_key = input_key
        self.padding = padding

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        if self.input_key == 'input_features':
            input_features = [{"input_features": feature["input_features"]} for feature in features]
        elif self.input_key == 'input_values':
            input_features = [{"input_values": feature["input_values"]} for feature in features]

        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # Determine pad_to_multiple_of based on accelerator's mixed precision
        if self.accelerator.distributed_type == DistributedType.XLA:
            max_length = 128
        else:
            max_length = None

        if self.accelerator.mixed_precision == "fp8":
            pad_to_multiple_of = 16
        elif self.accelerator.mixed_precision != "no":
            pad_to_multiple_of = 8
        else:
            pad_to_multiple_of = None

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors="pt",
        )

        labels_batch = self.processor.pad(
            labels=label_features,
            padding=self.padding,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors="pt",
        )

        # Replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, accelerator=accelerator, input_key='input_features', padding=True)


In [19]:
from torch.utils.data.dataloader import DataLoader

dataloaders = {
    'train': DataLoader(dataset['train'], batch_size=4, collate_fn=data_collator),
    'test': DataLoader(dataset['test'], batch_size=4, collate_fn=data_collator)
}

In [26]:
len(dataloaders["train"])

184

In [39]:

# Assume dataloaders is a dictionary with 'train' key for train_dataloader
dataloader = dataloaders["train"]
batch_size = 4

# Function to check the data types of a batch

def get_batch(data_loader, batch_idx):
    """
    Fetches a specific batch by index from a DataLoader using an iterator.
    
    Args:
    data_loader (torch.utils.data.DataLoader): The DataLoader.
    batch_idx (int): Index of the batch to fetch.
    
    Returns:
    Any: The batch at the specified index.
    """
    batch_iter = iter(data_loader)
    for _ in range(batch_idx):
        batch = next(batch_iter)
    return batch

def check_batch(batch):
    """
    Check the data types and other properties of each element in the batch.
    Handles dictionary-based batch structures.
    
    Args:
    batch (dict): A batch that may contain various elements keyed by descriptive names.
    """
    for key, item in batch.items():
        print(f"Key: {key}")
        print("  Type:", item.dtype)
        print("  Shape:", item.shape)
        if torch.isnan(item).any() or torch.isinf(item).any():
            print("  Contains NaNs or Infs.")
        else:
            print("  No NaNs or Infs detected.")


def compare_batches(data_loader):
    """
    Compares the 158th and 159th batch of the DataLoader.
    
    Args:
    data_loader (torch.utils.data.DataLoader): The DataLoader to inspect.
    """
    batch_157 = get_batch(data_loader, 157)
    batch_158 = get_batch(data_loader, 158)
    batch_159 = get_batch(data_loader, 159)
    
    print("Checking Batch 157:")
    check_batch(batch_157)
    
    print("Checking Batch 158:")
    check_batch(batch_158)
    
    print("\nChecking Batch 159:")
    check_batch(batch_159)

# Example usage (you'll need to replace `your_data_loader` with your actual DataLoader)
# compare_batches(your_data_loader)


compare_batches(dataloader)



Checking Batch 157:
Key: input_features
  Type: torch.float32
  Shape: torch.Size([4, 398, 160])
  No NaNs or Infs detected.
Key: attention_mask
  Type: torch.int32
  Shape: torch.Size([4, 398])
  No NaNs or Infs detected.
Key: labels
  Type: torch.int64
  Shape: torch.Size([4, 108])
  No NaNs or Infs detected.
Checking Batch 158:
Key: input_features
  Type: torch.float32
  Shape: torch.Size([4, 395, 160])
  No NaNs or Infs detected.
Key: attention_mask
  Type: torch.int32
  Shape: torch.Size([4, 395])
  No NaNs or Infs detected.
Key: labels
  Type: torch.int64
  Shape: torch.Size([4, 123])
  No NaNs or Infs detected.

Checking Batch 159:
Key: input_features
  Type: torch.float32
  Shape: torch.Size([4, 386, 160])
  No NaNs or Infs detected.
Key: attention_mask
  Type: torch.int32
  Shape: torch.Size([4, 386])
  No NaNs or Infs detected.
Key: labels
  Type: torch.int64
  Shape: torch.Size([4, 101])
  No NaNs or Infs detected.
