In [4]:
from datasets import load_from_disk
from dataclasses import dataclass, field
import torch

dataset = load_from_disk("/teamspace/studios/this_studio/datasets/fleurs-filtered-proc")

In [5]:
from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC
processor = Wav2Vec2BertProcessor.from_pretrained("/teamspace/studios/this_studio/models/facebook/w2v-bert-2.0-finetuned")
model= Wav2Vec2BertForCTC.from_pretrained("/teamspace/studios/this_studio/models/facebook/w2v-bert-2.0")

In [7]:
from accelerate import Accelerator, DistributedType
accelerator = Accelerator(mixed_precision="no")

In [9]:

from utils import DataCollatorCTCWithPadding
data_collator = DataCollatorCTCWithPadding(processor=processor, input_key='input_features', padding=True)


In [10]:
from torch.utils.data.dataloader import DataLoader

dataloaders = {
    'train': DataLoader(dataset['train'], batch_size=8, collate_fn=data_collator),
    'test': DataLoader(dataset['test'], batch_size=8, collate_fn=data_collator)
}

In [11]:
len(dataloaders["train"])

92

In [14]:

# Assume dataloaders is a dictionary with 'train' key for train_dataloader
dataloader = dataloaders["train"]
batch_size = 8

# Function to check the data types of a batch

def get_batch(data_loader, batch_idx):
    """
    Fetches a specific batch by index from a DataLoader using an iterator.
    
    Args:
    data_loader (torch.utils.data.DataLoader): The DataLoader.
    batch_idx (int): Index of the batch to fetch.
    
    Returns:
    Any: The batch at the specified index.
    """
    batch_iter = iter(data_loader)
    for _ in range(batch_idx):
        batch = next(batch_iter)
    return batch

def check_batch(batch):
    """
    Check the data types and other properties of each element in the batch.
    Handles dictionary-based batch structures.
    
    Args:
    batch (dict): A batch that may contain various elements keyed by descriptive names.
    """
    for key, item in batch.items():
        print(f"Key: {key}")
        print("  Type:", item.dtype)
        print("  Shape:", item.shape)
        if torch.isnan(item).any() or torch.isinf(item).any():
            print("  Contains NaNs or Infs.")
        else:
            print("  No NaNs or Infs detected.")


def compare_batches(data_loader):
    """
    Compares the 158th and 159th batch of the DataLoader.
    
    Args:
    data_loader (torch.utils.data.DataLoader): The DataLoader to inspect.
    """
    batch_156 = get_batch(data_loader, 75)
    # batch_157 = get_batch(data_loader, 76)
    # batch_158 = get_batch(data_loader, 77)
    # batch_159 = get_batch(data_loader, 78)
    
    print("Checking Batch 76:")
    check_batch(batch_156)

    # print("Checking Batch 76:")
    # check_batch(batch_157)
    
    # print("Checking Batch 77:")
    # check_batch(batch_158)
    
    # print("\nChecking Batch 78:")
    # check_batch(batch_159)

# Example usage (you'll need to replace `your_data_loader` with your actual DataLoader)
# compare_batches(your_data_loader)


compare_batches(dataloader)



Checking Batch 76:
Key: input_features
  Type: torch.float32
  Shape: torch.Size([8, 392, 160])
  No NaNs or Infs detected.
Key: attention_mask
  Type: torch.int32
  Shape: torch.Size([8, 392])
  No NaNs or Infs detected.
Key: labels
  Type: torch.int64
  Shape: torch.Size([8, 69])
  No NaNs or Infs detected.


In [42]:
batch_number = 78
start_index = (batch_number - 1) * batch_size
end_index = start_index + batch_size

# Slice the dataset to get the 77th batch
sub_dataset = dataset['train'].select(range(start_index, end_index))


In [43]:
for i in range(8):
    print(len(sub_dataset["input_features"][i]))
    
    print("label: ", len(sub_dataset["labels"][i]))

224
label:  31
344
label:  80


359
label:  86
326
label:  74
356
label:  78
278
label:  56
257
label:  64
197
label:  47
