In [1]:
# import
!pip install jiwer

Defaulting to user installation because normal site-packages is not writeable


In [2]:
# import
import os
import torch
import torch.nn as nn
from jiwer import wer
import soundfile as sf
from datasets import load_dataset,load_from_disk
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class DataParallelCriterion(nn.DataParallel):
    """Evaluate the wrapped criterion once per device and average the results.

    ``forward`` treats ``inputs`` as a sequence holding one entry per replica
    (presumably the un-gathered outputs of a data-parallel model -- confirm
    against the caller). The targets are scattered across ``self.device_ids``
    and each replica of the wrapped module is called on one
    (input, target) pair.
    """

    def forward(self, inputs, *targets, **kwargs):
        """Compute the mean of the per-device criterion values.

        Args:
            inputs: Sequence with one entry per device; its length decides how
                many replicas of the wrapped module are created.
            *targets: Positional targets, scattered across ``self.device_ids``.
                Only the first scattered positional target of each device is
                forwarded to the criterion.
            **kwargs: Keyword arguments, scattered alongside the targets and
                passed to the matching replica.

        Returns:
            A ``(mean_value, targets)`` tuple: the per-device criterion values
            averaged on the device of the first value, and the tuple of
            per-device targets that were used.

        Raises:
            ValueError: If no per-device criterion value was produced.
        """
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        # Each scattered entry is the tuple of positional targets for one
        # device; keep only the first positional target per device.
        targets = tuple(targets_per_gpu[0] for targets_per_gpu in targets)
        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
        if not outputs:
            raise ValueError("DataParallelCriterion produced no per-device outputs to reduce")

        # `torch.nn` has no `Reduce` attribute, so the former
        # `nn.Reduce.apply(*outputs)` raised AttributeError. Sum the values
        # explicitly instead, moving each one to the first value's device.
        reduce_device = outputs[0].device
        total = outputs[0]
        for partial in outputs[1:]:
            total = total + partial.to(reduce_device)
        return total / len(outputs), targets


def _criterion_parallel_apply(replicas, inputs, targets, kwargs):
    """Call each replica on its matching input, target and keyword arguments.

    The four sequences are walked in lockstep (stopping at the shortest one)
    and the calls run one after another in the calling thread.

    Returns:
        A list with one result per (replica, input, target, kwargs) group.
    """
    results = []
    for criterion, prediction, target, extra in zip(replicas, inputs, targets, kwargs):
        results.append(criterion(prediction, target, **extra))
    return results



In [None]:
# Make CUDA kernel launches synchronous so an error is reported at the call
# that caused it. The variable is an on/off switch, so use '1' (the previous
# value '3' is not a documented setting).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# GPUs visible to CUDA => the GPUs this notebook will use.
# This has to be set before the first CUDA call in the process (including
# torch.cuda.is_available()); setting it afterwards has no effect. The list
# is written without spaces. setdefault keeps a value that was already
# exported before the kernel started.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model_name = "facebook/wav2vec2-base-960h"
# model_name="facebook/wav2vec2-large-robust-ft-libri-960h"
# NOTE(review): "facebook/wav2vec2-base" looks like the pretraining-only
# checkpoint; confirm it ships a tokenizer and a trained CTC head, otherwise
# evaluate one of the "-960h" checkpoints listed above.
model_name="facebook/wav2vec2-base"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)

if torch.cuda.is_available():
    # Replicate the model across the visible GPUs; outputs are gathered on
    # device index 1.
    model = nn.DataParallel(model, output_device=1)
ds = load_from_disk("./data/datasets")

print(ds)
test_ds = ds['test']

def map_to_array(batch):
    """Read the audio file at ``batch["path"]`` and attach its samples.

    The samples returned by ``sf.read`` are stored under ``batch["speech"]``;
    the second value returned by ``sf.read`` is discarded.

    Returns:
        The same ``batch`` mapping, updated in place.
    """
    audio_path = batch["path"]
    samples, _unused = sf.read(audio_path)
    batch["speech"] = samples
    return batch

# Run map_to_array over the test split so each record gains a "speech" entry
# holding the audio samples read from its "path".
test_ds = test_ds.map(map_to_array)

def map_to_pred(batch):
    """Transcribe the audio in ``batch["speech"]`` and store it in ``batch["result"]``.

    Uses the notebook-level ``processor``, ``model`` and ``device``. The
    processor is called with a 16000 Hz sampling rate and pads to the longest
    item; the model runs in eval mode with gradients disabled, and the
    highest-scoring id at each position is decoded by the processor.

    Returns:
        The same ``batch`` mapping, updated in place.
    """
    features = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding="longest")

    model.eval()
    with torch.no_grad():
        output = model(features.input_values.to(device))

    # Greedy decoding: pick the top-scoring id along the last axis.
    batch["result"] = processor.batch_decode(output.logits.argmax(dim=-1))
    return batch

# Transcribe the test split, 16 records per call to map_to_pred.
result = test_ds.map(map_to_pred, batched=True, batch_size=16)

# Number of evaluated records.
print(len(result))

# Disabled alternative metric: exact-match accuracy.
# total = len(result)
# correct = sum(1 for data in result if data["transcription"] == data["result"])
# accuracy = (correct / total) * 100

# Word error rate of the predictions ("result") against the reference
# transcripts ("transcription").
references = result["transcription"]
hypotheses = result["result"]
print("WER:", wer(references, hypotheses))

In [5]:
# # Enables error logging
# os.environ['CUDA_LAUNCH_BLOCKING'] = '3'

# # GPUs visible to CUDA => the GPUs I will use
# os.environ["CUDA_VISIBLE_DEVICES"] = '4,5,6,7'
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# dataset_name= "test1"
# model_name = "facebook/wav2vec2-base-960h"
# dataset_dir ="./data/exist_test/dict2"

# processor = Wav2Vec2Processor.from_pretrained(model_name)
# model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)

# eval_prev_train_model(model_name,dataset_dir,device,dataset_name)