In [1]:
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

MODEL_ID_FRENCH = "jonatasgrosman/wav2vec2-large-xlsr-53-french"

MODEL_ID_GERMAN = "jonatasgrosman/wav2vec2-xls-r-1b-german"

MODEL_ID_ITALIAN = "dbdmg/wav2vec2-xls-r-300m-italian-robust"

MODEL_ID_JAPNESE = "jonatasgrosman/wav2vec2-large-xlsr-53-japanese"

In [4]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 

In [5]:
device

'cuda:0'

In [16]:
def multilingual_pretrained_asr(MODEL_ID, input_audio_path, device):
    
    wav, sr = torchaudio.load(input_audio_path)
    
    processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
    model = model.to(device)
    
    inputs = processor(wav, sampling_rate = 16_000, return_tensors = "pt", padding = True)
    
    inputs = inputs.to(device)
    
    with torch.no_grad():
        
        input_values = inputs['input_values'].squeeze(0)
        
        logits = model(input_values, attention_mask = inputs.attention_mask).logits
    
    predicted_ids = torch.argmax(logits, dim = -1)
    
    predicted_sentences = processor.batch_decode(predicted_ids)

    return predicted_sentences

## Loading the french audio

In [17]:
french_audio_path = "/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/data/audio_set_2_chunks/french/chunk_1/0021-0021_001_phone-O1-000026-000603.wav"

In [None]:
"0021-0021_001_phone-O1-000026-000603"

'0021-0021_001_phone-O1-000026-000603'

In [18]:
original_text = "Collection de voix du projet Z Y deux-mille-vingt-et-un zéro sept trente, je suis l'enregistreur un."

In [19]:
multilingual_pretrained_asr(MODEL_ID_FRENCH, french_audio_path, device = device)

['collections de voie du projet zady grec deux mille vingt-et unz zéro sept trente je suis lenregistrern']

## Loading the german audio

In [21]:
german_audio_path = "/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/data/audio_set_2_chunks/german/chunk_1/0021-0021_001_phone-O1-002140-002337.wav"

original_text = "Ja, Schatz, was essen wir denn heute Abend?"

In [22]:
multilingual_pretrained_asr(MODEL_ID_GERMAN, german_audio_path, device = device)

['ja schatz was esst man denn heute abend']

## Loading the italian audio

In [24]:
italian_audio_path = "/home/IITB/ai-at-ieor/23m1508/Shivam_23M1508/Interspeech/data/audio_set_2_chunks/italian/chunk_1/0002-0002_003_phone-O1-002346-002637.wav"

original_text = "OK, ci ci sei Andrea?"

In [25]:
multilingual_pretrained_asr(MODEL_ID_ITALIAN, italian_audio_path, device = device)

['okai ci sai andrea']

## Loading the japnese audio

### Checking the embedding size of facebook/wav2vec2-xls-r-1b model

In [7]:
import json
import csv
import re
import os
import torch
import random
import torchaudio
import torch.nn as nn
import numpy as np
import pandas as pd
from jiwer import wer
from tqdm.auto import tqdm
import torch.nn.functional as F
from IPython.display import Audio
from functools import partial
from torch.utils.data import Dataset

from transformers import Wav2Vec2Model
import torch.nn as nn
from torch.utils.data import DataLoader
# AdamW is best optimizer
from torch.optim import AdamW
from transformers import get_scheduler
from transformers import Wav2Vec2Model,Wav2Vec2Processor, Wav2Vec2FeatureExtractor

In [2]:
model_path = "facebook/wav2vec2-xls-r-1b"

In [5]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

In [4]:
model = Wav2Vec2Model.from_pretrained(model_path)

In [6]:
model = model.to(device)

In [8]:
feature_extractor = Wav2Vec2FeatureExtractor(feature_size = 1, 
                                             sampling_rate = 16000, 
                                             padding_value = 0.0, 
                                             do_normalize = True, 
                                             return_attention_mask = True)

In [11]:
waveform, sr = torchaudio.load('/scratch/IITB/ai-at-ieor/23m1508/23m1508_backup/audio_set_2_chunks/german/chunk_1/0021-0021_001_phone-O1-002140-002337.wav')

In [12]:
waveform

tensor([[ 0.0010,  0.0015,  0.0016,  ..., -0.0055, -0.0059, -0.0067]])

In [None]:
input_values = feature_extractor(
                                waveform, sampling_rate=16000,return_tensors="pt"
                                )["input_values"].squeeze(0)

In [16]:
input_values = input_values.to(device)

In [17]:
out = model(input_values)

In [24]:
out['last_hidden_state'].shape

torch.Size([1, 98, 1280])

In [1]:
import torch

In [2]:
checkpoint_path = "/scratch/IITB/ai-at-ieor/23m1508/23m1508_backup/checkpoints_mtp/w_till_epoch_1/w_till_epoch_2/multilingual_asr_model_1130000_1150000.pt"

In [3]:
checkpoint = torch.load(checkpoint_path, map_location = 'cpu', weights_only = False)

In [None]:
from transformers import Wav2Vec2Model

In [None]:
# Loaidng the pretrained model
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-2b")

In [31]:
state_dict_keys = list(model.state_dict().keys())

In [32]:
state_dict_keys

['masked_spec_embed',
 'feature_extractor.conv_layers.0.conv.weight',
 'feature_extractor.conv_layers.0.conv.bias',
 'feature_extractor.conv_layers.0.layer_norm.weight',
 'feature_extractor.conv_layers.0.layer_norm.bias',
 'feature_extractor.conv_layers.1.conv.weight',
 'feature_extractor.conv_layers.1.conv.bias',
 'feature_extractor.conv_layers.1.layer_norm.weight',
 'feature_extractor.conv_layers.1.layer_norm.bias',
 'feature_extractor.conv_layers.2.conv.weight',
 'feature_extractor.conv_layers.2.conv.bias',
 'feature_extractor.conv_layers.2.layer_norm.weight',
 'feature_extractor.conv_layers.2.layer_norm.bias',
 'feature_extractor.conv_layers.3.conv.weight',
 'feature_extractor.conv_layers.3.conv.bias',
 'feature_extractor.conv_layers.3.layer_norm.weight',
 'feature_extractor.conv_layers.3.layer_norm.bias',
 'feature_extractor.conv_layers.4.conv.weight',
 'feature_extractor.conv_layers.4.conv.bias',
 'feature_extractor.conv_layers.4.layer_norm.weight',
 'feature_extractor.conv_layer

In [None]:
# Use only one time since in this i have extatct the first 48 layers weights
state_dict = checkpoint['model_state_dict']  # assuming it's saved like this

In [24]:
# state_dict.keys()

In [18]:
# Strip off "projector.wav2vec2." prefix
base_model_state_dict = {
k.replace("projector.wav2vec2.", ""): v
for k, v in state_dict.items()
if k.startswith("projector.wav2vec2.")
}

In [33]:
base_model_state_dict.keys()

dict_keys(['masked_spec_embed', 'feature_extractor.conv_layers.0.conv.weight', 'feature_extractor.conv_layers.0.conv.bias', 'feature_extractor.conv_layers.0.layer_norm.weight', 'feature_extractor.conv_layers.0.layer_norm.bias', 'feature_extractor.conv_layers.1.conv.weight', 'feature_extractor.conv_layers.1.conv.bias', 'feature_extractor.conv_layers.1.layer_norm.weight', 'feature_extractor.conv_layers.1.layer_norm.bias', 'feature_extractor.conv_layers.2.conv.weight', 'feature_extractor.conv_layers.2.conv.bias', 'feature_extractor.conv_layers.2.layer_norm.weight', 'feature_extractor.conv_layers.2.layer_norm.bias', 'feature_extractor.conv_layers.3.conv.weight', 'feature_extractor.conv_layers.3.conv.bias', 'feature_extractor.conv_layers.3.layer_norm.weight', 'feature_extractor.conv_layers.3.layer_norm.bias', 'feature_extractor.conv_layers.4.conv.weight', 'feature_extractor.conv_layers.4.conv.bias', 'feature_extractor.conv_layers.4.layer_norm.weight', 'feature_extractor.conv_layers.4.layer_