In [63]:
!pip install datasets
!pip install audiomentations
!pip install transformers==4.11.3
!pip install librosa
!pip install jiwer



In [64]:
import torchaudio
import torch
import os
import matplotlib.pyplot as plt
import ipywidgets
import soundfile as sf
import torch
import datasets
import librosa
import random
import tensorflow as tf
import pandas as pd
import numpy as np
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift, AddBackgroundNoise,Gain
from typing import Tuple
from torch import Tensor
from torch.utils.data import Dataset
from IPython.display import Audio, display,HTML
from ipywidgets import IntProgress
from transformers import Wav2Vec2CTCTokenizer,Wav2Vec2ForCTC,Wav2Vec2Processor,Trainer,TrainingArguments,Wav2Vec2FeatureExtractor
from datasets import load_dataset, load_metric,Dataset,concatenate_datasets,set_caching_enabled, ClassLabel
from sklearn.model_selection import train_test_split

## AUGMENTING THE DATA

##### First thing first, import the timit dataset from the hugging face hub

In [65]:
timit = load_dataset("timit_asr")
#timit = load_dataset("timit_asr")

  0%|          | 0/2 [00:00<?, ?it/s]

##### We ahve a DatasetDict object that looks like follows

In [66]:
timit

DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
        num_rows: 4620
    })
    test: Dataset({
        features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
        num_rows: 1680
    })
})

##### Add some noyse to the data as follow:

In [67]:
waveform, sample_rate = torchaudio.load(timit['train']['audio'][5]['path'])

augment = Compose([
    AddGaussianNoise(min_amplitude=0.0003, max_amplitude=0.0025, p=0.3),
    #AddBackgroundNoise(min_snr_in_db=3,max_snr_in_db=30, noise_rms="relative", p=0.2),
    Gain(min_gain_in_db=-15.0,max_gain_in_db=5.0, p=0.4),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.4),
    #PitchShift(min_semitones=-4, max_semitones=4, p=0.3),
    #Shift(min_fraction=-0.8, max_fraction=0.8, p=0.1),
])


# Augment/transform/perturb the audio data
augmented_samples = augment(samples=np.array(waveform), sample_rate=16000)


In [68]:
print(timit['train'][5]["text"])
Audio(data=np.asarray(augmented_samples[0]), autoplay=True, rate=16000)

Ambidextrous pickpockets accomplish more.


In [69]:
print(timit['train'][5]["text"])
Audio(data=np.asarray(timit['train'][5]["audio"]["array"]), autoplay=True, rate=16000)

Ambidextrous pickpockets accomplish more.


##### AUGMENT THE TRAINING AND DATA and save all the relevant infomrmation (waveform, samplrate, text) in a dictionary: trnasform a DatasetDict to a dict object because it is faster to work with. Will do the opposite later (dict to DatasetDict)

In [75]:
def dictionary_audio(dataset, augment):
    augment = Compose([
    AddGaussianNoise(min_amplitude=0.0003, max_amplitude=0.0025, p=0.7),
    Gain(min_gain_in_db=-15.0,max_gain_in_db=5.0, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    ])

    dict_data = [{}] * len(dataset)

    for i in range(len(dataset)):
        waveform, sample_rate = torchaudio.load(dataset[i]['file'])
        if augment:
            waveform = augment(samples=np.array(waveform[0]),
                                sample_rate=16000)
        

        new_item = {'audio': {
            'wave': np.array(waveform),
            'sample_rate': sample_rate}
            }

        for k, v in dataset[i].items():
            if k != 'audio':
                new_item[k] = v

            dict_data[i] = new_item
    
    return dict_data

        

In [76]:
augmented_data = dictionary_audio(timit['train'], augment=True)#dirty train

In [77]:
test_data = dictionary_audio(timit['test'], augment=True)#dirdy test 

In [82]:
audio = augmented_data[16]['audio']['wave']
text = augmented_data[16]['text']
print(text)
Audio(data=np.asarray(audio), autoplay=True, rate=16000)

He picked up nine pairs of socks for each brother.


In [80]:
audio = test_data[12]['audio']['wave']
text = test_data[12]['text']
print(text)
Audio(data=np.asarray(audio), autoplay=True, rate=16000)

Don't ask me to carry an oily rag like that.


In [83]:
timit = timit.remove_columns(["phonetic_detail", "word_detail", "dialect_region", "id", "sentence_type", "speaker_id"])

In [84]:
#display some random samples of the dataset
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))
    
show_random_elements(timit["train"].remove_columns(["audio", "file"]), num_examples=10)

Unnamed: 0,text
0,She had your dark suit in greasy wash water all year.
1,"Again, the analyticity of the two curves guarantees that such intervals exist."
2,"But if she wasn't interested, she'd just go back to the same life she'd left."
3,"His problem concerns longitudes, latitudes, and angular velocities."
4,Lullaby and goodnight his voice shook.
5,The rich should invest in black zircons instead of stylish shoes.
6,"Twenty-five, the sheik replied."
7,My desires are simple: give me one informative paragraph on the subject.
8,Along the main thoroughfares hardly a house had not been peppered.
9,Don't ask me to carry an oily rag like that.


### CREATE A VOCABOLARY

In [85]:
#We write a mapping function that concatenates all transcriptions into one long transcription and then 
#transforms the string into a set of chars. It is important to pass the argument batched=True to the map(...) 
#function so that the mapping function has access to all transcriptions at once.

import re
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"]'

def remove_special_characters(batch):
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower() + " "
    return batch

def extract_all_chars(batch):
    all_text = " ".join(batch["text"])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}

In [86]:
timit = timit.map(remove_special_characters)

0ex [00:00, ?ex/s]

0ex [00:00, ?ex/s]

In [87]:
vocabs = timit.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=timit.column_names["train"])
vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

{'j': 0,
 'm': 1,
 'b': 2,
 'l': 3,
 "'": 4,
 'c': 5,
 'w': 6,
 'a': 7,
 'z': 8,
 'h': 9,
 'k': 10,
 'y': 11,
 't': 12,
 'r': 13,
 'p': 14,
 ' ': 15,
 'd': 16,
 'e': 17,
 'o': 18,
 'u': 19,
 's': 20,
 'n': 21,
 'q': 22,
 'x': 23,
 'i': 24,
 'g': 25,
 'v': 26,
 'f': 27}

In [88]:
#show_random_elements(librispeech_samples.remove_columns(["audio", "file"]), num_examples=10)

In [89]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)
                   

30

In [90]:
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [91]:
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor, Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [92]:
augmented_data[0]["file"]

'/root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV'

In [93]:
augmented_data[0]["audio"]

{'wave': array([-0.00040229,  0.0003597 ,  0.00132362, ...,  0.        ,
         0.        ,  0.        ], dtype=float32),
 'sample_rate': 16000}

In [94]:
rand_int = random.randint(0, len(timit["train"]))

print("Target text:", timit["train"][rand_int]["text"])
print("Input array shape:", np.asarray(augmented_data[rand_int]["audio"]["wave"]).shape)
print("Sampling rate:", augmented_data[rand_int]["audio"]["sample_rate"])

Target text: he must rearrange matters so that two performers do not bump into each other 
Input array shape: (73012,)
Sampling rate: 16000


### PREPARE THE AUGMENTED DATESET FOR TRAINING 

In [95]:
#PREPARE THE TRAIN-AUGMENTED DATA----> for some reason, this code dosen't work if I create a function. 
inp = []
inp_l = []
for i in range(len(augmented_data)):

    # batched output is "un-batched" to ensure mapping is correct
    inputs = processor(augmented_data[i]['audio']["wave"], sampling_rate=augmented_data[i]['audio']["sample_rate"]).input_values
    input_lenght = len(inputs[0])

    inp.append(list(inputs))
    inp_l.append(input_lenght)


with processor.as_target_processor():
    labels = processor(timit['train']["text"]).input_ids
    
#getting the correct shape
shape = np.array(inp).shape[0]
inp = np.reshape(np.array(inp),(shape,))
dictt = {"input_values":list(inp), "input_length":inp_l, "labels": labels}


train_augmented = {'train':dictt}



In [96]:
#PREPARE THE TEST-NOT_AUGMENTED DATA
inp = []
inp_l = []
for i in range(len(test_data)):

    # batched output is "un-batched" to ensure mapping is correct
    inputs = processor(test_data[i]['audio']["wave"], sampling_rate=test_data[i]['audio']["sample_rate"]).input_values
    input_lenght = len(inputs[0])

    inp.append(list(inputs))
    inp_l.append(input_lenght)


with processor.as_target_processor():
    labels = processor(timit['test']["text"]).input_ids
    
#getting the correct shape
shape = np.array(inp).shape[0]
inp = np.reshape(np.array(inp),(shape,))
dict_test = {"input_values":list(inp), "input_length":inp_l, "labels": labels}

test_clean = {'test':dict_test}



In [97]:
from datasets import Dataset, ClassLabel, Sequence, Features, Value

dataset = timit
# using your `Dict` object
for k,v in train_augmented.items():
    dataset[k] = Dataset.from_dict(v)
    
for k,v in test_clean.items():
    dataset[k] = Dataset.from_dict(v)


In [98]:
dataset

DatasetDict({
    train: Dataset({
        features: ['input_values', 'input_length', 'labels'],
        num_rows: 4620
    })
    test: Dataset({
        features: ['input_values', 'input_length', 'labels'],
        num_rows: 1680
    })
})

In [99]:
np.array(dataset['train']["input_length"]).shape

(4620,)

In [100]:
max_input_length_in_sec = 4.0
dataset["train"] = dataset["train"].filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])

  0%|          | 0/5 [00:00<?, ?ba/s]

### TRAIN THE MODEL

In [101]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [102]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [103]:
wer_metric = load_metric("wer")

Downloading:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

In [104]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [105]:
len(processor.tokenizer)

30

In [106]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53", 
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)

Downloading:   0%|          | 0.00/1.73k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.18G [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/wav2vec2-large-xlsr-53 were not used when initializing Wav2Vec2ForCTC: ['project_q.bias', 'quantizer.weight_proj.bias', 'quantizer.weight_proj.weight', 'quantizer.codevectors', 'project_q.weight', 'project_hid.weight', 'project_hid.bias']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to u

In [107]:
model.freeze_feature_extractor()

In [108]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./wav2vec2-large-xlsr-WOLOF",
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=40,
  save_steps=500,
  eval_steps=500,
  logging_steps=500,
  learning_rate=3e-4,
  warmup_steps=1000,
  save_total_limit=2,
)

In [109]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    tokenizer=processor.feature_extractor,
)

In [110]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running training *****
  Num examples = 3978
  Num Epochs = 40
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 2
  Total optimization steps = 4960
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /usr/local/src/pytorch/aten/src/ATen/native/BinaryOps.cpp:461.)
  return torch.floor_divide(self, other)


Step,Training Loss,Validation Loss,Wer
500,3.7585,2.876966,1.0
1000,1.1747,0.67054,0.517056
1500,0.3955,0.64642,0.454276
2000,0.2365,0.684367,0.439391
2500,0.1602,0.722129,0.432637
3000,0.1171,0.7605,0.436496
3500,0.094,0.725482,0.422162
4000,0.0785,0.720105,0.414927
4500,0.061,0.760018,0.41279


The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running Evaluation *****
  Num examples = 1680
  Batch size = 8
Saving model checkpoint to ./wav2vec2-large-xlsr-WOLOF/checkpoint-500
Configuration saved in ./wav2vec2-large-xlsr-WOLOF/checkpoint-500/config.json
Model weights saved in ./wav2vec2-large-xlsr-WOLOF/checkpoint-500/pytorch_model.bin
Configuration saved in ./wav2vec2-large-xlsr-WOLOF/checkpoint-500/preprocessor_config.json
The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running Evaluation *****
  Num examples = 1680
  Batch size = 8
Saving model checkpoint to ./wav2vec2-large-xlsr-WOLOF/checkpoint-1000
Configuration saved in ./wav2vec2-large-xlsr-WOLOF/checkpoint-1000/config.json
Model weights saved in ./wav2vec2-large-xlsr-WOLOF/checkpoint-1000/pytorch_model.bin
Configurat

TrainOutput(global_step=4960, training_loss=0.6175588900043119, metrics={'train_runtime': 14692.1577, 'train_samples_per_second': 10.83, 'train_steps_per_second': 0.338, 'total_flos': 1.3717006790225381e+19, 'train_loss': 0.6175588900043119, 'epoch': 40.0})

In [111]:
model.save_pretrained("wav2vec2-large-xlsr-WOLOF")
processor.save_pretrained("wav2vec2-large-xlsr-WOLOF")

Configuration saved in wav2vec2-large-xlsr-WOLOF/config.json
Model weights saved in wav2vec2-large-xlsr-WOLOF/pytorch_model.bin
Configuration saved in wav2vec2-large-xlsr-WOLOF/preprocessor_config.json
tokenizer config file saved in wav2vec2-large-xlsr-WOLOF/tokenizer_config.json
Special tokens file saved in wav2vec2-large-xlsr-WOLOF/special_tokens_map.json


In [112]:
model = Wav2Vec2ForCTC.from_pretrained("wav2vec2-large-xlsr-WOLOF").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("wav2vec2-large-xlsr-WOLOF")

loading configuration file wav2vec2-large-xlsr-WOLOF/config.json
Model config Wav2Vec2Config {
  "_name_or_path": "facebook/wav2vec2-large-xlsr-53",
  "activation_dropout": 0.0,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForCTC"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 768,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": true,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "mean",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": true,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_dropout": 0.0,
  "feat_extract_norm": "layer",
  "feat_proj_dropout": 0.0,
  "feat_quantizer_dropout": 0.0,
  "final_dropout": 0.0,
  "hidden_act": "g

In [113]:
def map_to_result(batch):
    with torch.no_grad():
        input_values = torch.tensor(batch["input_values"], device="cuda").unsqueeze(0)
        logits = model(input_values).logits

    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch

In [114]:
results = dataset["test"].map(map_to_result, remove_columns=dataset["test"].column_names)

0ex [00:00, ?ex/s]

In [115]:
print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))

Test WER: 0.289


In [121]:
show_random_elements(results)

Unnamed: 0,pred_str,text
0,baskit ball can be an entertanings forc,basketball can be an entertaining sport
1,she had your dark suit in greasy wash water all year,she had your dark suit in greasy wash water all year
2,caseum makes bones and teach strong,calcium makes bones and teeth strong
3,it was not exaced on caniclgam poitube but they cridnot tos selfreiet,it was not exactly panic they gave way to but they could not just sit there
4,got no bisness over here on a stake out anyway,got no business over here on a stakeout anyway
5,that diogram makes sence only after much study,that diagram makes sense only after much study
6,she had your dark suit in greasy wash water all year,she had your dark suit in greasy wash water all year
7,the best ray to learns to solv extra problems,the best way to learn is to solve extra problems
8,what shall these effects be,what shall these effects be
9,it shoer felt as if it were broken,his shoulder felt as if it were broken


In [117]:
model.to("cuda")

with torch.no_grad():
    logits = model(torch.tensor(dataset["test"][:1]["input_values"], device="cuda")).logits

pred_ids = torch.argmax(logits, dim=-1)

# convert ids to tokens
" ".join(processor.tokenizer.convert_ids_to_tokens(pred_ids[0].tolist()))

't [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] h h e | | | b [PAD] [PAD] u [PAD] n n [PAD] g g u u [PAD] l l [PAD] o [PAD] [PAD] [PAD] | | w w a a [PAD] s [PAD] | | [PAD] p [PAD] l l e [PAD] [PAD] [PAD] s [PAD] e n n t t [PAD] l l [PAD] y [PAD] | | | s [PAD] [PAD] [PAD] i [PAD] [PAD] [PAD] t [PAD] u u u [PAD] [PAD] [PAD] [PAD] a a [PAD] [PAD] t [PAD] e e d d | | n n e e a a r r | | t h e e | | | s s [PAD] [PAD] o o [PAD] [PAD] [PAD] [PAD] r r r [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] |'