Import all packages here

In [1]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from datasets import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from torch.nn import CTCLoss
from tqdm import tqdm

from dataclasses import dataclass
from typing import Dict, List, Union

from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Trainer, TrainingArguments, Wav2Vec2CTCTokenizer, EarlyStoppingCallback

import librosa
from librosa.effects import trim
import librosa.display

from IPython.display import Audio

from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer

import pydub as pyd

import jiwer as jw

import pandas as pd 
import numpy as np 

import os

import matplotlib.pyplot as plt

import re


  from .autonotebook import tqdm as notebook_tqdm


Mounting Gdrive for the datasets

Setting up the Dataset

In [None]:
dataset_path = "dataset/"
metadata = "dataset.csv"

audio_directory = "dataset/"

# Create a dataframe for the transcript
dataframe = pd.read_csv(metadata)

# Preprocess transcript
def preprocess_text(text):
    text = text.upper()  # Convert text to uppercase
    text = re.sub(r"[^a-zA-Z0-9\s]", "", text)  # Remove non-alphanumeric characters (except spaces)
    return text

dataframe['clean_transcript'] = dataframe['Transcription'].apply(preprocess_text)

print(dataframe.head())

  File_Path Speaker         Transcription Session      clean_transcript
0     03M_1     03M  1 2 3 4 5 6 7 8 9 10       1  1 2 3 4 5 6 7 8 9 10
1     03M_2     03M                   ata       2                   ATA
2     03M_3     03M                   ana       3                   ANA
3     03M_4     03M                   ara       4                   ARA
4     03M_5     03M                  atha       5                  ATHA


Data Pre-Processing

In [3]:


# Function to preprocess audio and connect to transcripts
def combine_audio_with_transcript(directory, dataframe):
    audio_data = []    
     # Iterate over each row in the DataFrame
    for index, row in dataframe.iterrows():
        file_name = row['File_Path']  # Get the file name from the CSV (without .wav)
        transcript = row['clean_transcript']  # Get the transcript
        
        # Construct the full file path by combining directory and file name with .wav extension
        file_path = os.path.join(directory, f"{file_name}.wav")
        
        # Check if the file exists in the audio directory
        if os.path.exists(file_path):
            try:
                # Append the processed data along with the transcript
                audio_data.append({
                    "file_path": file_path,
                    "transcript": transcript,
                })
            except Exception as e:
                print(f"Error processing file {file_path}: {e}")
    
    return audio_data

# Preprocess audio files and connect to transcripts
audio_data_with_transcripts = combine_audio_with_transcript(audio_directory, dataframe)


Convert Dataset into Pytorch Dataset

In [4]:
audio_with_transcript_dataframe = pd.DataFrame(audio_data_with_transcripts)

for index, row in audio_with_transcript_dataframe.iterrows():
    file_path = row['file_path']
    try:
        # Load the audio file
        audio, sr = librosa.load(file_path, sr=None)  # sr=None to preserve the original sample rate
        
        # Display basic information about the audio
        print(f"Audio file {file_path} loaded successfully")
        print(f"Audio length: {len(audio)} samples")
        print(f"Sample rate: {sr} Hz\n")
    
    except Exception as e:
        print(f"Error loading {file_path}: {e}")

Audio file dataset/03M_1.wav loaded successfully
Audio length: 717953 samples
Sample rate: 48000 Hz

Audio file dataset/03M_2.wav loaded successfully
Audio length: 561018 samples
Sample rate: 48000 Hz

Audio file dataset/03M_3.wav loaded successfully
Audio length: 697276 samples
Sample rate: 48000 Hz

Audio file dataset/03M_4.wav loaded successfully
Audio length: 643847 samples
Sample rate: 48000 Hz

Audio file dataset/03M_5.wav loaded successfully
Audio length: 864068 samples
Sample rate: 48000 Hz

Audio file dataset/03M_6.wav loaded successfully
Audio length: 728750 samples
Sample rate: 48000 Hz

Audio file dataset/03M_7.wav loaded successfully
Audio length: 715758 samples
Sample rate: 48000 Hz

Audio file dataset/03M_8.wav loaded successfully
Audio length: 484327 samples
Sample rate: 48000 Hz

Audio file dataset/03M_9.wav loaded successfully
Audio length: 453342 samples
Sample rate: 48000 Hz

Audio file dataset/03M_10.wav loaded successfully
Audio length: 483563 samples
Sample rate:

Build The Model

In [5]:
# Initialize the processor
asr_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

# Initialize the model
asr_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
asr_tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-large-960h")


print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA version:", torch.version.cuda)
    print("GPU Name:", torch.cuda.get_device_name(0))


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


CUDA available: True
CUDA version: 12.1
GPU Name: NVIDIA GeForce RTX 3060 Laptop GPU


In [6]:
sampling_rate=16000

def prepare_dataset(batch):
    """
    Prepare dataset with correct data types.
    """
    # Load the audio file
    audio, sr = librosa.load(batch['file_path'], sr=sampling_rate, mono=True)
    
    # Trim silence from the audio
    trimmed_audio, _ = librosa.effects.trim(audio, top_db=20)
    
    # Process audio
    audio_features = asr_processor(
        trimmed_audio, 
        sampling_rate=sampling_rate, 
        padding=False,
        return_tensors=None
    ).input_values[0]
    
    # Ensure audio features are float32
    audio_features = audio_features.astype(np.float32)
    
    # Process the transcript
    with asr_tokenizer.as_target_tokenizer():
        labels = asr_tokenizer(
            batch['transcript'],
            padding=False,
            return_tensors=None
        ).input_ids
    
    return {
        "input_values": audio_features,
        "labels": labels
    }

def custom_data_collator(batch):
    """
    Custom collator with correct data types.
    """
    # Get max length in the batch
    max_audio_length = max(len(x["input_values"]) for x in batch)
    max_label_length = max(len(x["labels"]) for x in batch)
    
    batch_audio = []
    batch_labels = []
    
    for sample in batch:
        # Pad audio
        audio_length = len(sample["input_values"])
        padded_audio = np.pad(
            sample["input_values"],
            (0, max_audio_length - audio_length),
            mode='constant',
            constant_values=0
        )
        batch_audio.append(padded_audio)
        
        # Pad labels
        label_length = len(sample["labels"])
        padded_labels = np.pad(
            sample["labels"],
            (0, max_label_length - label_length),
            mode='constant',
            constant_values=-100
        )
        batch_labels.append(padded_labels)
    
    # Convert to tensors with correct dtypes
    batch_audio = torch.tensor(batch_audio, dtype=torch.float32)
    batch_labels = torch.tensor(batch_labels, dtype=torch.long)
    
    return {
        "input_values": batch_audio,
        "labels": batch_labels
    }

In [7]:
dataset = Dataset.from_pandas(audio_with_transcript_dataframe)
processed_dataset = dataset.map(prepare_dataset, remove_columns=["file_path", "transcript"])

# Check if dataset is good
processed_dataframe = processed_dataset.to_pandas()
print(processed_dataframe.head())

Map:   0%|          | 0/2054 [00:00<?, ? examples/s]

Map: 100%|██████████| 2054/2054 [00:19<00:00, 103.78 examples/s]


                                        input_values  \
0  [-0.08597469, -0.06963344, -0.093899734, -0.04...   
1  [-0.013693145, 0.00078774075, -0.010232744, -0...   
2  [-0.22992292, -0.20895448, -0.16145737, -0.156...   
3  [-0.025115486, -0.03066014, -0.0062442967, -0....   
4  [-0.10101376, -0.12410567, -0.19890745, -0.223...   

                                              labels  
0  [3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, ...  
1                                          [7, 6, 7]  
2                                          [7, 9, 7]  
3                                         [7, 13, 7]  
4                                      [7, 6, 11, 7]  


In [8]:
split_dataset = processed_dataset.train_test_split(test_size=0.3, seed=42)

train_dataset = split_dataset["train"]
val_dataset = split_dataset["test"]

# Print dataset details
print("Training dataset size:", len(train_dataset))
print("Validation dataset size:", len(val_dataset))

Training dataset size: 1437
Validation dataset size: 617


In [9]:
train_samples = len(train_dataset)
batch_size = 4  # Your batch size
n_gpu = 1  # Assuming single GPU
gradient_accumulation = 1  # Default value

steps_per_epoch = train_samples // (batch_size * n_gpu * gradient_accumulation)

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    dataloader_num_workers=0,
    fp16=True,
    num_train_epochs=50,
    logging_steps=steps_per_epoch,
    eval_steps=steps_per_epoch,
    save_steps=steps_per_epoch,
    learning_rate=1e-7,
    lr_scheduler_type="linear",  # Options: 'linear', 'cosine', 'constant', etc.

    load_best_model_at_end=True,  # Enable saving the best model
    metric_for_best_model="eval_loss",  # Replace with your desired metric
    greater_is_better=False,
    evaluation_strategy="epoch",  # Evaluate at the end of each epoch
    save_strategy="epoch",
)

trainer = Trainer(
    model=asr_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=custom_data_collator,
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01)
    ],
)




Training the Model

In [10]:
trainer.train()

  batch_audio = torch.tensor(batch_audio, dtype=torch.float32)
  2%|▏         | 360/18000 [01:42<59:23,  4.95it/s]  

{'loss': 328.1096, 'grad_norm': 8283.017578125, 'learning_rate': 9.812777777777778e-08, 'epoch': 1.0}


                                                   
  2%|▏         | 360/18000 [02:49<59:23,  4.95it/s]

{'eval_loss': 882.4044799804688, 'eval_runtime': 66.433, 'eval_samples_per_second': 9.288, 'eval_steps_per_second': 1.174, 'epoch': 1.0}


  4%|▍         | 719/18000 [04:25<1:06:03,  4.36it/s]  

{'loss': 296.1453, 'grad_norm': 13944.5966796875, 'learning_rate': 9.613333333333333e-08, 'epoch': 1.99}


                                                     
  4%|▍         | 720/18000 [04:52<55:51,  5.16it/s]

{'eval_loss': 752.8445434570312, 'eval_runtime': 27.6285, 'eval_samples_per_second': 22.332, 'eval_steps_per_second': 2.823, 'epoch': 2.0}


  6%|▌         | 1077/18000 [06:28<1:15:40,  3.73it/s]

{'loss': 267.09, 'grad_norm': 8257.0537109375, 'learning_rate': 9.413888888888889e-08, 'epoch': 2.99}


                                                      
  6%|▌         | 1080/18000 [06:56<1:02:29,  4.51it/s]

{'eval_loss': 657.5489501953125, 'eval_runtime': 27.5921, 'eval_samples_per_second': 22.361, 'eval_steps_per_second': 2.827, 'epoch': 3.0}


  8%|▊         | 1436/18000 [08:29<1:10:04,  3.94it/s] 

{'loss': 252.0491, 'grad_norm': 34423.39453125, 'learning_rate': 9.214444444444444e-08, 'epoch': 3.99}


                                                      
  8%|▊         | 1440/18000 [08:56<1:01:13,  4.51it/s]

{'eval_loss': 584.2203369140625, 'eval_runtime': 26.011, 'eval_samples_per_second': 23.721, 'eval_steps_per_second': 2.999, 'epoch': 4.0}


 10%|▉         | 1796/18000 [10:28<1:00:32,  4.46it/s] 

{'loss': 237.6925, 'grad_norm': 14945.193359375, 'learning_rate': 9.015e-08, 'epoch': 4.99}


                                                      
 10%|█         | 1800/18000 [10:55<54:03,  4.99it/s]

{'eval_loss': 528.9442138671875, 'eval_runtime': 26.1875, 'eval_samples_per_second': 23.561, 'eval_steps_per_second': 2.979, 'epoch': 5.0}


 12%|█▏        | 2154/18000 [12:28<1:00:09,  4.39it/s] 

{'loss': 225.9935, 'grad_norm': 15181.80078125, 'learning_rate': 8.815555555555556e-08, 'epoch': 5.98}


                                                      
 12%|█▏        | 2160/18000 [12:56<56:51,  4.64it/s]

{'eval_loss': 495.9640808105469, 'eval_runtime': 26.177, 'eval_samples_per_second': 23.57, 'eval_steps_per_second': 2.98, 'epoch': 6.0}


 14%|█▍        | 2514/18000 [14:29<1:00:13,  4.29it/s] 

{'loss': 221.5845, 'grad_norm': 3071.61083984375, 'learning_rate': 8.61611111111111e-08, 'epoch': 6.98}


                                                      
 14%|█▍        | 2520/18000 [14:57<55:48,  4.62it/s]

{'eval_loss': 471.59161376953125, 'eval_runtime': 27.4651, 'eval_samples_per_second': 22.465, 'eval_steps_per_second': 2.84, 'epoch': 7.0}


 16%|█▌        | 2872/18000 [16:29<1:07:05,  3.76it/s] 

{'loss': 214.8292, 'grad_norm': 4771.80615234375, 'learning_rate': 8.416666666666666e-08, 'epoch': 7.98}


                                                      
 16%|█▌        | 2880/18000 [16:57<59:23,  4.24it/s]

{'eval_loss': 457.1165771484375, 'eval_runtime': 26.3047, 'eval_samples_per_second': 23.456, 'eval_steps_per_second': 2.965, 'epoch': 8.0}


 18%|█▊        | 3231/18000 [18:28<1:04:15,  3.83it/s] 

{'loss': 215.2543, 'grad_norm': 8185.98681640625, 'learning_rate': 8.217222222222222e-08, 'epoch': 8.97}


                                                      
 18%|█▊        | 3240/18000 [18:57<48:56,  5.03it/s]

{'eval_loss': 444.1304626464844, 'eval_runtime': 26.6075, 'eval_samples_per_second': 23.189, 'eval_steps_per_second': 2.932, 'epoch': 9.0}


 20%|█▉        | 3590/18000 [20:28<51:51,  4.63it/s]   

{'loss': 212.3584, 'grad_norm': 4621.6474609375, 'learning_rate': 8.018333333333333e-08, 'epoch': 9.97}


                                                    
 20%|██        | 3600/18000 [20:57<50:04,  4.79it/s]

{'eval_loss': 434.31634521484375, 'eval_runtime': 26.6887, 'eval_samples_per_second': 23.118, 'eval_steps_per_second': 2.923, 'epoch': 10.0}


 22%|██▏       | 3949/18000 [22:28<1:00:55,  3.84it/s] 

{'loss': 209.9837, 'grad_norm': 5041.05615234375, 'learning_rate': 7.818888888888889e-08, 'epoch': 10.97}


                                                      
 22%|██▏       | 3960/18000 [22:56<46:55,  4.99it/s]

{'eval_loss': 426.9784240722656, 'eval_runtime': 26.6031, 'eval_samples_per_second': 23.193, 'eval_steps_per_second': 2.932, 'epoch': 11.0}


 24%|██▍       | 4308/18000 [24:26<56:28,  4.04it/s]   

{'loss': 207.4234, 'grad_norm': 2475.098876953125, 'learning_rate': 7.619444444444444e-08, 'epoch': 11.97}


                                                    
 24%|██▍       | 4320/18000 [24:55<48:15,  4.72it/s]

{'eval_loss': 421.3515625, 'eval_runtime': 26.4001, 'eval_samples_per_second': 23.371, 'eval_steps_per_second': 2.955, 'epoch': 12.0}


 26%|██▌       | 4667/18000 [26:27<54:27,  4.08it/s]   

{'loss': 205.6322, 'grad_norm': 4558.8173828125, 'learning_rate': 7.419999999999999e-08, 'epoch': 12.96}


                                                      
 26%|██▌       | 4680/18000 [26:57<56:17,  3.94it/s]

{'eval_loss': 417.4057312011719, 'eval_runtime': 26.4729, 'eval_samples_per_second': 23.307, 'eval_steps_per_second': 2.946, 'epoch': 13.0}


 28%|██▊       | 5026/18000 [28:27<57:09,  3.78it/s]   

{'loss': 203.5777, 'grad_norm': 5716.26904296875, 'learning_rate': 7.220555555555555e-08, 'epoch': 13.96}


                                                    
 28%|██▊       | 5040/18000 [28:57<44:37,  4.84it/s]

{'eval_loss': 412.7198181152344, 'eval_runtime': 26.3902, 'eval_samples_per_second': 23.38, 'eval_steps_per_second': 2.956, 'epoch': 14.0}


 30%|██▉       | 5386/18000 [30:27<57:46,  3.64it/s]   

{'loss': 201.5002, 'grad_norm': 4997.00146484375, 'learning_rate': 7.021111111111111e-08, 'epoch': 14.96}


                                                      
 30%|███       | 5400/18000 [30:57<47:30,  4.42it/s]

{'eval_loss': 409.8057556152344, 'eval_runtime': 26.5243, 'eval_samples_per_second': 23.262, 'eval_steps_per_second': 2.941, 'epoch': 15.0}


 32%|███▏      | 5745/18000 [32:26<43:31,  4.69it/s]   

{'loss': 201.1082, 'grad_norm': 1932.05126953125, 'learning_rate': 6.821666666666667e-08, 'epoch': 15.96}


                                                      
 32%|███▏      | 5760/18000 [32:56<1:00:19,  3.38it/s]

{'eval_loss': 406.5759582519531, 'eval_runtime': 26.5151, 'eval_samples_per_second': 23.27, 'eval_steps_per_second': 2.942, 'epoch': 16.0}


 34%|███▍      | 6103/18000 [34:26<46:20,  4.28it/s]   

{'loss': 200.1062, 'grad_norm': 7155.6982421875, 'learning_rate': 6.622777777777777e-08, 'epoch': 16.95}


                                                    
 34%|███▍      | 6120/18000 [34:57<45:48,  4.32it/s]

{'eval_loss': 405.2737121582031, 'eval_runtime': 26.569, 'eval_samples_per_second': 23.223, 'eval_steps_per_second': 2.936, 'epoch': 17.0}


 36%|███▌      | 6462/18000 [36:27<49:37,  3.87it/s]   

{'loss': 201.7015, 'grad_norm': 3077.346435546875, 'learning_rate': 6.423333333333333e-08, 'epoch': 17.95}


                                                    
 36%|███▌      | 6480/18000 [36:57<38:57,  4.93it/s]

{'eval_loss': 401.7613525390625, 'eval_runtime': 26.4357, 'eval_samples_per_second': 23.34, 'eval_steps_per_second': 2.951, 'epoch': 18.0}


 38%|███▊      | 6821/18000 [38:27<47:54,  3.89it/s]   

{'loss': 197.4869, 'grad_norm': 6337.1474609375, 'learning_rate': 6.223888888888889e-08, 'epoch': 18.95}


                                                    
 38%|███▊      | 6840/18000 [38:58<39:34,  4.70it/s]

{'eval_loss': 399.51104736328125, 'eval_runtime': 26.3742, 'eval_samples_per_second': 23.394, 'eval_steps_per_second': 2.957, 'epoch': 19.0}


 40%|███▉      | 7180/18000 [40:29<56:35,  3.19it/s]   

{'loss': 197.3745, 'grad_norm': 5996.64599609375, 'learning_rate': 6.024444444444444e-08, 'epoch': 19.94}


                                                    
 40%|████      | 7200/18000 [41:01<33:07,  5.43it/s]

{'eval_loss': 398.4128723144531, 'eval_runtime': 26.7979, 'eval_samples_per_second': 23.024, 'eval_steps_per_second': 2.911, 'epoch': 20.0}


 42%|████▏     | 7539/18000 [42:32<39:57,  4.36it/s]   

{'loss': 194.6766, 'grad_norm': 4359.08447265625, 'learning_rate': 5.825e-08, 'epoch': 20.94}


                                                      
 42%|████▏     | 7560/18000 [43:04<44:33,  3.90it/s]

{'eval_loss': 397.317626953125, 'eval_runtime': 26.3557, 'eval_samples_per_second': 23.41, 'eval_steps_per_second': 2.96, 'epoch': 21.0}


 44%|████▍     | 7898/18000 [44:32<40:35,  4.15it/s]   

{'loss': 196.6336, 'grad_norm': 5146.947265625, 'learning_rate': 5.6255555555555554e-08, 'epoch': 21.94}


                                                    
 44%|████▍     | 7920/18000 [45:04<34:33,  4.86it/s]

{'eval_loss': 395.0817565917969, 'eval_runtime': 27.1043, 'eval_samples_per_second': 22.764, 'eval_steps_per_second': 2.878, 'epoch': 22.0}


 46%|████▌     | 8257/18000 [46:34<36:53,  4.40it/s]   

{'loss': 193.8911, 'grad_norm': 4440.58447265625, 'learning_rate': 5.426111111111111e-08, 'epoch': 22.94}


                                                    
 46%|████▌     | 8280/18000 [47:06<43:18,  3.74it/s]

{'eval_loss': 393.35321044921875, 'eval_runtime': 26.5092, 'eval_samples_per_second': 23.275, 'eval_steps_per_second': 2.942, 'epoch': 23.0}


 48%|████▊     | 8616/18000 [48:34<45:26,  3.44it/s]   

{'loss': 193.4066, 'grad_norm': 3239.440185546875, 'learning_rate': 5.226666666666666e-08, 'epoch': 23.93}


                                                    
 48%|████▊     | 8640/18000 [49:06<33:36,  4.64it/s]

{'eval_loss': 392.02227783203125, 'eval_runtime': 26.4954, 'eval_samples_per_second': 23.287, 'eval_steps_per_second': 2.944, 'epoch': 24.0}


 50%|████▉     | 8975/18000 [50:35<41:30,  3.62it/s]   

{'loss': 192.9606, 'grad_norm': 23027.962890625, 'learning_rate': 5.027222222222222e-08, 'epoch': 24.93}


                                                    
 50%|█████     | 9000/18000 [51:08<32:50,  4.57it/s]

{'eval_loss': 391.0492248535156, 'eval_runtime': 26.3523, 'eval_samples_per_second': 23.413, 'eval_steps_per_second': 2.96, 'epoch': 25.0}


 52%|█████▏    | 9334/18000 [52:35<39:03,  3.70it/s]   

{'loss': 192.9902, 'grad_norm': 5021.11474609375, 'learning_rate': 4.828333333333333e-08, 'epoch': 25.93}


                                                    
 52%|█████▏    | 9360/18000 [53:08<35:03,  4.11it/s]

{'eval_loss': 389.66607666015625, 'eval_runtime': 26.5498, 'eval_samples_per_second': 23.239, 'eval_steps_per_second': 2.938, 'epoch': 26.0}


 54%|█████▍    | 9693/18000 [54:36<36:31,  3.79it/s]   

{'loss': 191.399, 'grad_norm': 2608.0673828125, 'learning_rate': 4.628888888888889e-08, 'epoch': 26.93}


                                                    
 54%|█████▍    | 9720/18000 [55:09<26:05,  5.29it/s]

{'eval_loss': 388.42327880859375, 'eval_runtime': 26.6997, 'eval_samples_per_second': 23.109, 'eval_steps_per_second': 2.921, 'epoch': 27.0}


 56%|█████▌    | 10052/18000 [56:38<31:54,  4.15it/s]  

{'loss': 191.4111, 'grad_norm': 3283.738525390625, 'learning_rate': 4.43e-08, 'epoch': 27.92}


                                                     
 56%|█████▌    | 10080/18000 [57:11<30:48,  4.28it/s]

{'eval_loss': 386.55938720703125, 'eval_runtime': 26.6034, 'eval_samples_per_second': 23.193, 'eval_steps_per_second': 2.932, 'epoch': 28.0}


 58%|█████▊    | 10411/18000 [58:37<33:04,  3.82it/s]   

{'loss': 193.5593, 'grad_norm': 3346.573486328125, 'learning_rate': 4.230555555555556e-08, 'epoch': 28.92}


                                                     
 58%|█████▊    | 10440/18000 [59:12<33:34,  3.75it/s]

{'eval_loss': 385.5793762207031, 'eval_runtime': 26.7165, 'eval_samples_per_second': 23.094, 'eval_steps_per_second': 2.92, 'epoch': 29.0}


 60%|█████▉    | 10770/18000 [1:00:39<27:53,  4.32it/s] 

{'loss': 191.5001, 'grad_norm': 3020.443359375, 'learning_rate': 4.0311111111111104e-08, 'epoch': 29.92}


                                                       
 60%|██████    | 10800/18000 [1:01:14<26:35,  4.51it/s]

{'eval_loss': 384.60601806640625, 'eval_runtime': 26.8072, 'eval_samples_per_second': 23.016, 'eval_steps_per_second': 2.91, 'epoch': 30.0}


 62%|██████▏   | 11129/18000 [1:02:41<27:25,  4.18it/s]   

{'loss': 188.7158, 'grad_norm': 14116.6162109375, 'learning_rate': 3.831666666666666e-08, 'epoch': 30.91}


                                                       
 62%|██████▏   | 11160/18000 [1:03:15<27:39,  4.12it/s]

{'eval_loss': 383.5542297363281, 'eval_runtime': 26.8843, 'eval_samples_per_second': 22.95, 'eval_steps_per_second': 2.901, 'epoch': 31.0}


 64%|██████▍   | 11488/18000 [1:04:43<26:54,  4.03it/s]   

{'loss': 190.2566, 'grad_norm': 6325.67041015625, 'learning_rate': 3.632222222222222e-08, 'epoch': 31.91}


                                                       
 64%|██████▍   | 11520/18000 [1:05:18<23:41,  4.56it/s]

{'eval_loss': 381.72442626953125, 'eval_runtime': 26.9582, 'eval_samples_per_second': 22.887, 'eval_steps_per_second': 2.893, 'epoch': 32.0}


 66%|██████▌   | 11847/18000 [1:06:45<27:20,  3.75it/s]   

{'loss': 189.2523, 'grad_norm': 3301.321044921875, 'learning_rate': 3.432777777777778e-08, 'epoch': 32.91}


                                                       
 66%|██████▌   | 11880/18000 [1:07:20<23:18,  4.38it/s]

{'eval_loss': 380.8316955566406, 'eval_runtime': 26.8081, 'eval_samples_per_second': 23.015, 'eval_steps_per_second': 2.91, 'epoch': 33.0}


 68%|██████▊   | 12206/18000 [1:08:46<22:45,  4.24it/s]   

{'loss': 187.671, 'grad_norm': 5312.08349609375, 'learning_rate': 3.233333333333333e-08, 'epoch': 33.91}


                                                       
 68%|██████▊   | 12240/18000 [1:09:21<20:39,  4.65it/s]

{'eval_loss': 380.50958251953125, 'eval_runtime': 26.8465, 'eval_samples_per_second': 22.983, 'eval_steps_per_second': 2.905, 'epoch': 34.0}


 70%|██████▉   | 12565/18000 [1:10:47<23:24,  3.87it/s]   

{'loss': 187.0565, 'grad_norm': 9392.1220703125, 'learning_rate': 3.0338888888888884e-08, 'epoch': 34.9}


                                                       
 70%|███████   | 12600/18000 [1:11:23<16:56,  5.31it/s]

{'eval_loss': 380.5693054199219, 'eval_runtime': 27.1793, 'eval_samples_per_second': 22.701, 'eval_steps_per_second': 2.87, 'epoch': 35.0}


 72%|███████▏  | 12924/18000 [1:12:49<20:41,  4.09it/s]   

{'loss': 188.2927, 'grad_norm': 2747.873291015625, 'learning_rate': 2.8344444444444442e-08, 'epoch': 35.9}


                                                       
 72%|███████▏  | 12960/18000 [1:13:25<14:58,  5.61it/s]

{'eval_loss': 379.88580322265625, 'eval_runtime': 26.8948, 'eval_samples_per_second': 22.941, 'eval_steps_per_second': 2.9, 'epoch': 36.0}


 74%|███████▍  | 13283/18000 [1:14:49<22:55,  3.43it/s]   

{'loss': 186.4734, 'grad_norm': 6466.2021484375, 'learning_rate': 2.635e-08, 'epoch': 36.9}


                                                       
 74%|███████▍  | 13320/18000 [1:15:26<18:20,  4.25it/s]

{'eval_loss': 379.9446105957031, 'eval_runtime': 27.1054, 'eval_samples_per_second': 22.763, 'eval_steps_per_second': 2.878, 'epoch': 37.0}


 76%|███████▌  | 13642/18000 [1:16:52<19:46,  3.67it/s]  

{'loss': 185.9143, 'grad_norm': 4761.27978515625, 'learning_rate': 2.4355555555555553e-08, 'epoch': 37.89}


                                                       
 76%|███████▌  | 13680/18000 [1:17:28<14:02,  5.13it/s]

{'eval_loss': 379.358642578125, 'eval_runtime': 26.8397, 'eval_samples_per_second': 22.988, 'eval_steps_per_second': 2.906, 'epoch': 38.0}


 78%|███████▊  | 14001/18000 [1:18:53<19:55,  3.34it/s]   

{'loss': 188.597, 'grad_norm': 13057.791015625, 'learning_rate': 2.2361111111111112e-08, 'epoch': 38.89}


                                                       
 78%|███████▊  | 14040/18000 [1:19:30<12:29,  5.29it/s]

{'eval_loss': 378.5219421386719, 'eval_runtime': 26.7196, 'eval_samples_per_second': 23.092, 'eval_steps_per_second': 2.919, 'epoch': 39.0}


 80%|███████▉  | 14360/18000 [1:20:56<14:51,  4.08it/s]   

{'loss': 187.2775, 'grad_norm': 7786.59765625, 'learning_rate': 2.037222222222222e-08, 'epoch': 39.89}


                                                       
 80%|████████  | 14400/18000 [1:21:33<12:48,  4.69it/s]

{'eval_loss': 377.9352111816406, 'eval_runtime': 26.9131, 'eval_samples_per_second': 22.926, 'eval_steps_per_second': 2.898, 'epoch': 40.0}


 82%|████████▏ | 14719/18000 [1:22:58<15:50,  3.45it/s]  

{'loss': 185.7197, 'grad_norm': 7808.72265625, 'learning_rate': 1.837777777777778e-08, 'epoch': 40.89}


                                                       
 82%|████████▏ | 14760/18000 [1:23:35<12:04,  4.47it/s]

{'eval_loss': 377.92657470703125, 'eval_runtime': 26.9246, 'eval_samples_per_second': 22.916, 'eval_steps_per_second': 2.897, 'epoch': 41.0}


 84%|████████▍ | 15078/18000 [1:24:59<11:24,  4.27it/s]  

{'loss': 187.9294, 'grad_norm': 3270.6123046875, 'learning_rate': 1.638333333333333e-08, 'epoch': 41.88}


                                                       
 84%|████████▍ | 15120/18000 [1:25:37<11:06,  4.32it/s]

{'eval_loss': 377.3719177246094, 'eval_runtime': 27.0016, 'eval_samples_per_second': 22.851, 'eval_steps_per_second': 2.889, 'epoch': 42.0}


 86%|████████▌ | 15437/18000 [1:27:05<10:20,  4.13it/s]  

{'loss': 187.6281, 'grad_norm': 3788.36767578125, 'learning_rate': 1.438888888888889e-08, 'epoch': 42.88}


                                                       
 86%|████████▌ | 15480/18000 [1:27:43<08:06,  5.18it/s]

{'eval_loss': 377.3569641113281, 'eval_runtime': 26.8791, 'eval_samples_per_second': 22.955, 'eval_steps_per_second': 2.902, 'epoch': 43.0}


 88%|████████▊ | 15796/18000 [1:29:09<09:29,  3.87it/s]  

{'loss': 184.3207, 'grad_norm': 6119.443359375, 'learning_rate': 1.2399999999999999e-08, 'epoch': 43.88}


                                                       
 88%|████████▊ | 15840/18000 [1:29:47<07:30,  4.80it/s]

{'eval_loss': 377.283447265625, 'eval_runtime': 26.8177, 'eval_samples_per_second': 23.007, 'eval_steps_per_second': 2.909, 'epoch': 44.0}


 90%|████████▉ | 16155/18000 [1:31:11<07:44,  3.97it/s]  

{'loss': 186.2444, 'grad_norm': 3744.87060546875, 'learning_rate': 1.0405555555555556e-08, 'epoch': 44.88}


                                                       
 90%|█████████ | 16200/18000 [1:31:50<07:04,  4.24it/s]

{'eval_loss': 377.18890380859375, 'eval_runtime': 26.9113, 'eval_samples_per_second': 22.927, 'eval_steps_per_second': 2.898, 'epoch': 45.0}


 92%|█████████▏| 16514/18000 [1:33:14<05:46,  4.29it/s]  

{'loss': 188.0058, 'grad_norm': 5054.2548828125, 'learning_rate': 8.411111111111111e-09, 'epoch': 45.87}


                                                       
 92%|█████████▏| 16560/18000 [1:33:52<06:04,  3.95it/s]

{'eval_loss': 376.95855712890625, 'eval_runtime': 26.8731, 'eval_samples_per_second': 22.96, 'eval_steps_per_second': 2.903, 'epoch': 46.0}


 94%|█████████▎| 16873/18000 [1:35:15<04:07,  4.55it/s]  

{'loss': 182.8389, 'grad_norm': 6639.6298828125, 'learning_rate': 6.416666666666666e-09, 'epoch': 46.87}


                                                       
 94%|█████████▍| 16920/18000 [1:35:54<04:04,  4.42it/s]

{'eval_loss': 376.906005859375, 'eval_runtime': 27.045, 'eval_samples_per_second': 22.814, 'eval_steps_per_second': 2.884, 'epoch': 47.0}


 96%|█████████▌| 17232/18000 [1:37:16<03:33,  3.60it/s]  

{'loss': 185.858, 'grad_norm': 5391.82470703125, 'learning_rate': 4.422222222222222e-09, 'epoch': 47.87}


                                                       
 96%|█████████▌| 17280/18000 [1:37:55<03:07,  3.84it/s]

{'eval_loss': 376.7323913574219, 'eval_runtime': 26.6632, 'eval_samples_per_second': 23.141, 'eval_steps_per_second': 2.925, 'epoch': 48.0}


 98%|█████████▊| 17591/18000 [1:39:20<01:43,  3.94it/s]  

{'loss': 185.1453, 'grad_norm': 9355.12890625, 'learning_rate': 2.4277777777777777e-09, 'epoch': 48.86}


                                                       
 98%|█████████▊| 17640/18000 [1:39:59<01:20,  4.48it/s]

{'eval_loss': 376.70648193359375, 'eval_runtime': 27.0725, 'eval_samples_per_second': 22.791, 'eval_steps_per_second': 2.881, 'epoch': 49.0}


100%|█████████▉| 17950/18000 [1:41:21<00:13,  3.82it/s]

{'loss': 182.9929, 'grad_norm': 5640.31640625, 'learning_rate': 4.388888888888889e-10, 'epoch': 49.86}


                                                       
100%|██████████| 18000/18000 [1:42:04<00:00,  3.73it/s]

{'eval_loss': 376.72137451171875, 'eval_runtime': 27.0789, 'eval_samples_per_second': 22.785, 'eval_steps_per_second': 2.88, 'epoch': 50.0}


100%|██████████| 18000/18000 [1:42:07<00:00,  2.94it/s]

{'train_runtime': 6127.719, 'train_samples_per_second': 11.725, 'train_steps_per_second': 2.937, 'train_loss': 203.2487152777778, 'epoch': 50.0}





TrainOutput(global_step=18000, training_loss=203.2487152777778, metrics={'train_runtime': 6127.719, 'train_samples_per_second': 11.725, 'train_steps_per_second': 2.937, 'total_flos': 3.4487335393815025e+18, 'train_loss': 203.2487152777778, 'epoch': 50.0})

Training Results

Visualization of Results using MatPlotLib

Save the model

More evaluation 