In [1]:
# Hugging Face Datasets
from datasets import load_dataset, concatenate_datasets

# Data processing and metrics
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    recall_score,
    f1_score,
    confusion_matrix,
)

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, clear_output


dataset_name = 'cairocode/IEMO_WAV_002'
dataset = load_dataset(dataset_name)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset['train'][0]

{'audio': {'path': 'Ses01F_impro01_F000.wav',
  'array': array([-0.0050354 , -0.00497437, -0.0038147 , ..., -0.00265503,
         -0.00317383, -0.00418091]),
  'sampling_rate': 16000},
 'label': 0,
 'valence': 2.5,
 'arousal': 2.5,
 'domination': 2.5,
 'arousal_norm': 3.75,
 'valence_norm': -1.25,
 'speakerID': 2,
 'utterance_id': 'Ses01F_impro01_F000',
 'transcript': 'Excuse me.',
 'speaker_id': 2}

: 

In [None]:
# Hugging Face Datasets
from datasets import load_dataset, concatenate_datasets

# Data processing and metrics
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    recall_score,
    f1_score,
    confusion_matrix,
)

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, clear_output


dataset_name = 'cairocode/IEMO_WAV_002'
dataset = load_dataset(dataset_name)
from datasets import DatasetDict
from PIL import Image
import matplotlib.pyplot as plt
import io
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import logging
from transformers import logging as transformers_logging
# Set the logging level to ERROR to suppress warnings
logging.getLogger("transformers").setLevel(logging.ERROR)
transformers_logging.set_verbosity_error()


# Updated audio_to_image function to return an image
def audio_to_image(audio_array, sample_rate=16000, model_name="facebook/wav2vec2-base-960h"):
    """
    Takes an input audio array and uses Wav2Vec2 to generate an image representation.

    Args:
        audio_array (numpy array): Input audio data as a 1D array.
        sample_rate (int): Sampling rate of the audio.
        model_name (str): Pretrained Wav2Vec2 model name.

    Returns:
        PIL.Image: Image representation of the audio features.
    """
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    model = Wav2Vec2Model.from_pretrained(model_name)

    inputs = processor(audio_array, sampling_rate=sample_rate, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
 
    hidden_states = outputs.last_hidden_state.squeeze(0).numpy()

    plt.figure(figsize=(10, 6))
    plt.imshow(hidden_states.T, aspect="auto", origin="lower", cmap="viridis")
    plt.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    plt.close()
    buf.seek(0)

    image = Image.open(buf).convert('RGB')
    image = image.resize((224, 224))  # Resize to 224x224
    buf.close()
    return image

# Function to process a batch
def process_batch(batch, sample_rate=16000, model_name="facebook/wav2vec2-base-960h"):
    images = []
    for audio_data in tqdm(batch['audio'], desc="Processing batch"):
        image = audio_to_image(audio_data['array'], sample_rate=sample_rate, model_name=model_name)
        images.append(image)
    return {'image': images}


# Map the function to the dataset in batches
new_dataset = dataset.map(
    process_batch,
    batched=True,
    batch_size = 64,
    remove_columns=['audio'],
)

# Save the updated dataset
new_dataset.push_to_hub('IEMO_Wav2Vec2')


Processing batch: 100%|██████████| 64/64 [01:10<00:00,  1.11s/it]
Processing batch: 100%|██████████| 64/64 [01:12<00:00,  1.13s/it]
Processing batch: 100%|██████████| 64/64 [01:14<00:00,  1.16s/it]
Processing batch: 100%|██████████| 64/64 [01:29<00:00,  1.39s/it]
Processing batch: 100%|██████████| 64/64 [01:14<00:00,  1.16s/it]
Processing batch: 100%|██████████| 64/64 [01:10<00:00,  1.10s/it]
Processing batch: 100%|██████████| 64/64 [01:18<00:00,  1.23s/it]
Processing batch: 100%|██████████| 64/64 [01:11<00:00,  1.12s/it]
Processing batch: 100%|██████████| 64/64 [01:10<00:00,  1.10s/it]
Processing batch: 100%|██████████| 64/64 [01:15<00:00,  1.17s/it]
Processing batch: 100%|██████████| 64/64 [01:12<00:00,  1.13s/it]
Processing batch: 100%|██████████| 64/64 [01:14<00:00,  1.16s/it]
Processing batch: 100%|██████████| 64/64 [01:10<00:00,  1.11s/it]
Processing batch: 100%|██████████| 64/64 [01:09<00:00,  1.08s/it]
Processing batch: 100%|██████████| 64/64 [01:16<00:00,  1.20s/it]
Processing

In [None]:
new_dataset