In [None]:
# Transformer-based version of Tacotron
!pip install nltk
import nltk
nltk.download('cmudict')

In [2]:
import torch
from torch import nn
from pathlib import Path
import random


# from tokenizers import Tokenizer
from torch.utils.data import Dataset, DataLoader
from nltk.corpus import cmudict
pronouncing_dict = cmudict.dict()



In [None]:
# Quick sanity check: look up one word in the CMU Pronouncing Dictionary.
word = "hello"
phonemes = pronouncing_dict.get(word.lower(), ["Not found"])
print(phonemes)  # a list of pronunciations (each a list of phoneme strings), e.g. [['HH', 'AH0', 'L', 'OW1'], ...]; ["Not found"] if absent

In [6]:

def text_to_phonemes(word, pron_dict=None):
    """Return the first CMUdict pronunciation of ``word``, or ``None`` if it is unknown.

    Args:
        word: A single word. Case is ignored. If the raw word is not found,
            leading/trailing punctuation (e.g. the comma in ``"printing,"``) is
            stripped and the lookup retried, so words split off a sentence still
            resolve.
        pron_dict: Optional mapping ``word -> list of pronunciations``. Defaults
            to the module-level ``pronouncing_dict`` loaded from CMUdict.

    Returns:
        A list of phoneme strings, or ``None`` for out-of-vocabulary words
        (the dataset maps these to ``[UNK]``).
    """
    import string

    if pron_dict is None:
        pron_dict = pronouncing_dict
    word = word.lower()
    # Try the raw token first so entries containing punctuation (e.g. apostrophes)
    # still match, then fall back to the punctuation-stripped form.
    for candidate in (word, word.strip(string.punctuation)):
        if candidate and candidate in pron_dict:
            return pron_dict[candidate][0]  # Take the first pronunciation if multiple exist
    return None  # Handle OOV words separately
# Build the phoneme vocabulary from every pronunciation in CMUdict.
# Special tokens come first so their ids are fixed ([PAD]=0, [EOS]=1, [UNK]=2);
# the remaining phonemes are sorted so indexing is reproducible across runs.
_special_tokens = ['[PAD]', '[EOS]', '[UNK]']
_all_phonemes = {
    phoneme
    for pronunciations in pronouncing_dict.values()
    for pronunciation in pronunciations
    for phoneme in pronunciation
}
phoneme_vocab = _special_tokens + sorted(_all_phonemes)
phoneme_to_id = dict(zip(phoneme_vocab, range(len(phoneme_vocab))))

vocab_size = len(phoneme_vocab)  # CMUdict phonemes plus the three special tokens

def phonemes_to_indices(phoneme_seq, mapping=None):
    """Map a sequence of phoneme symbols to integer ids.

    Args:
        phoneme_seq: An iterable of phoneme strings (e.g. ``['HH', 'AH0']``).
            A single ``str`` is treated as ONE token (e.g. ``'[UNK]'``) rather
            than being iterated character by character. Iterating would silently
            map ``'[UNK]'`` to the ids of the phonemes ``'N'`` and ``'K'``.
        mapping: Optional ``symbol -> id`` dict; defaults to ``phoneme_to_id``.

    Returns:
        A list of ids. Symbols missing from ``mapping`` are dropped.
    """
    if mapping is None:
        mapping = phoneme_to_id
    if isinstance(phoneme_seq, str):
        phoneme_seq = [phoneme_seq]
    return [mapping[p] for p in phoneme_seq if p in mapping]


In [None]:
vocab_size  # size of phoneme_vocab: CMUdict phonemes plus [PAD], [EOS], [UNK]

In [8]:

# import re
# Hugging Face access token, read from the environment instead of being hardcoded
# in the notebook (set it in your shell or with `%env HF_TOKEN=...`).
# If the variable is unset, HF_TOKEN is None.
import os
HF_TOKEN = os.environ.get("HF_TOKEN")



In [None]:
import wandb
!wandb login

In [11]:
# Hyperparameters
epochs = 10
block_size = 80                      # phoneme ids per example after padding/truncation
batch_size = 32
src_vocab_size = vocab_size          # phoneme vocabulary size (incl. special tokens)
phenome_embeddings_dims = 512
embeddings_dims = phenome_embeddings_dims
prenet_encoder_embeddings_dims = 512
attn_dropout = 0.1
no_of_heads = 4  # IMP needs to be thoroughly calculated
dropout = 0.1
max_lr = 6e-4
no_of_decoder_layers = 8  # IMP needs to be thoroughly calculated
weight_decay_optim = 0.01
log_mel_features = 80
kernel_size = 5
stride = (2, 10)
sr = 16000
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
SAMPLING_RATE = 16000
N_MELS = 80  # 80-channel Mel spectrogram
WINDOW_DURATION = 0.050  # 50 milliseconds
STRIDE_DURATION = 0.0125  # 12.5 milliseconds
max_t = 512  # spectrogram frames per example after padding/truncation
n_channels = N_MELS
clip = 1.0
embeddings_dims_decoder = 256

# Each attention head gets embeddings_dims // no_of_heads features and the heads
# are concatenated back to embeddings_dims, so the split must be exact.
assert embeddings_dims % no_of_heads == 0, "embeddings_dims must be divisible by no_of_heads"

In [12]:
# Place every tensor/parameter created without an explicit `device=` on `device`
# (e.g. the PostNet BatchNorm and the decoder's `scaled_factor` below rely on this).
torch.set_default_device(device)

In [None]:

!pip install datasets
from tabnanny import verbose
from datasets import Audio, load_dataset

gs = load_dataset("keithito/lj_speech", token=HF_TOKEN)
# Every feature-extraction step below passes sr=SAMPLING_RATE to librosa, so the
# audio must actually be at that rate. Decode/resample it on access.
# NOTE(review): LJSpeech's native rate is 22.05 kHz, not 16 kHz — confirm via gs['train'].features.
gs = gs.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))

print(gs)

# Peek at one example.
audio_input = gs['train'][0]["audio"]
transcription = gs["train"][0]["text"]

In [None]:

MAX_DURATION_IN_SECONDS = 10
SPLIT_SEED = 42  # fixed seed so the train/val split is reproducible across runs

# NOTE: this rebinds `gs`, so re-running the cell would split the train half again.
gs = gs['train'].train_test_split(test_size=0.2, seed=SPLIT_SEED)

import librosa
from tqdm import tqdm


def is_audio_length_in_range(input_length):
    """Return True if a clip of ``input_length`` seconds is shorter than MAX_DURATION_IN_SECONDS."""
    return input_length < MAX_DURATION_IN_SECONDS


def _add_duration_and_filter(split):
    """Add a ``duration`` column (seconds, read from each audio file) and keep only short-enough clips."""
    durations = [
        librosa.get_duration(path=split[i]['audio']['path'])
        for i in tqdm(range(len(split)))
    ]
    with_duration = split.add_column("duration", durations)
    return with_duration.filter(is_audio_length_in_range, input_columns=["duration"])


truncated_gs_train = _add_duration_and_filter(gs['train'])
truncated_gs_val = _add_duration_and_filter(gs['test'])

In [None]:

import numpy as np


# NOTE(review): both sizes are multiplied by MAX_DURATION_IN_SECONDS, which gives an
# 8000-sample window and a 2000-sample hop at 16 kHz rather than WINDOW_DURATION /
# STRIDE_DURATION alone. TTSDataset derives its stop-token frame index from the same
# hop_length, so the values are left unchanged here — confirm they are intended.
n_fft = int(WINDOW_DURATION * MAX_DURATION_IN_SECONDS * SAMPLING_RATE)
hop_length = int(STRIDE_DURATION * MAX_DURATION_IN_SECONDS * SAMPLING_RATE)


def _log_mel_features(split, text_column='normalized_text'):
    """Compute log-mel spectrograms for every row of ``split``.

    Returns three parallel lists:
      * spectrograms of shape (N_MELS, frames), in dB relative to each clip's max,
      * transcripts taken from ``text_column``,
      * clip durations in seconds (the ``duration`` column).
    """
    outputs, texts, durations = [], [], []
    for i in tqdm(range(len(split))):
        row = split[i]  # decode the row once instead of once per field
        S = librosa.feature.melspectrogram(
            y=row['audio']['array'],
            sr=SAMPLING_RATE,
            n_mels=N_MELS,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=n_fft,
            fmax=SAMPLING_RATE // 2,
        )
        outputs.append(librosa.power_to_db(S, ref=np.max))
        texts.append(row[text_column])
        durations.append(row['duration'])
    return outputs, texts, durations


# Both splits use the same transcript column. Validation previously used the raw
# `text` column while training used `normalized_text`; raw text presumably keeps
# numerals/abbreviations that have no CMUdict entry and would become [UNK].
train_outputs, train_texts, train_duration = _log_mel_features(truncated_gs_train)
val_outputs, val_texts, val_duration = _log_mel_features(truncated_gs_val)

In [19]:
# import math
import re
# print(round(random.random(), 1))
class TTSDataset(Dataset):
  """Pairs padded, [-1, 1]-normalized spectrograms with phoneme-id sequences.

  Each item is ``(spectrograms, tokenized, stop)`` where:
    * ``spectrograms`` is ``{'input_ids': S, 'labels': S}``, with ``S`` a float32
      tensor of shape (N_MELS, max_t) (the same tensor under both keys);
    * ``tokenized`` is a LongTensor of ``block_size`` phoneme ids, padded with 0 ([PAD]);
    * ``stop`` is an (N_MELS, max_t) tensor that is 1.0 in the column of the last
      real (unpadded) frame and 0 elsewhere.
  """

  def __init__(self, outputs, texts, duration):
    self.data = outputs        # per-example spectrograms, (n_mels, frames)
    self.texts = texts         # transcripts, parallel to `data`
    self.max_t = max_t         # frames per example after padding/truncation
    self.duration = duration   # clip durations in seconds, parallel to `data`

  def __len__(self):
    return len(self.data)

  def pad_phoneme_sequence(self, phoneme_seq, max_length):
    """Right-pad ``phoneme_seq`` with the [PAD] id (0) to ``max_length``, or truncate it.

    Returns a new list; the argument is not modified.
    """
    pad_token = 0
    if len(phoneme_seq) < max_length:
      return phoneme_seq + [pad_token] * (max_length - len(phoneme_seq))
    return phoneme_seq[:max_length]

  def pad_to_max_t(self, spectrogram, max_t):
    """Zero-pad (at the end of the time axis) or truncate an (n_mels, t) array to ``max_t`` frames."""
    n_mels, t = spectrogram.shape
    if t < max_t:
      return np.pad(spectrogram, ((0, 0), (0, max_t - t)), mode='constant')
    return spectrogram[:, :max_t]

  def clean(self, desc):
    """Remove anything between ``<`` and ``>`` from ``desc`` (not used by __getitem__)."""
    return re.sub(r'<[^>]*>', '', desc)

  def _text_to_ids(self, text):
    """Lower-case ``text``, split it into words and concatenate their phoneme ids ([UNK] for OOV words)."""
    unk_id = phoneme_to_id['[UNK]']
    ids = []
    # str.split() with no argument collapses runs of whitespace, so double spaces
    # do not produce empty "words" that would each become [UNK].
    for word in text.lower().split():
      phonemes = text_to_phonemes(word)
      if phonemes is None:
        # Use the [UNK] id directly: phonemes_to_indices('[UNK]') iterates the
        # string's characters and returns the ids of the phonemes 'N' and 'K'.
        ids.append(unk_id)
      else:
        ids.extend(phonemes_to_indices(phonemes))
    return ids

  def __getitem__(self, idx):
    spectrogram = self.pad_to_max_t(self.data[idx], self.max_t)
    spectrogram = torch.tensor(spectrogram, dtype=torch.float32)

    # Stop-token target: mark the column of the last real (unpadded) frame.
    original_frames = int((SAMPLING_RATE / hop_length) * self.duration[idx])
    # Clamp to [0, max_t - 1]. Without the lower bound, a clip shorter than one hop
    # gives -1, which would mark the final padded frame instead of the first.
    last_frame = max(min(original_frames, self.max_t) - 1, 0)
    stop = torch.zeros((N_MELS, self.max_t), device=device)
    stop[:, last_frame] = 1.0

    # Min-max normalize the (padded) spectrogram to [-1, 1].
    epsilon = 1e-8  # To avoid division by zero
    spectrogram_min = spectrogram.min()
    spectrogram_max = spectrogram.max()
    spectrogram = 2 * ((spectrogram - spectrogram_min) / (spectrogram_max - spectrogram_min + epsilon)) - 1

    phenomes = self.pad_phoneme_sequence(self._text_to_ids(self.texts[idx]), block_size)
    tokenized = torch.tensor(phenomes, dtype=torch.long)

    spectrograms = {'input_ids': spectrogram, 'labels': spectrogram}
    return spectrograms, tokenized, stop

In [20]:
def collate_fn(batch):
    """Stack ``(spectrograms, tokenized, stop)`` dataset items into batch tensors.

    Returns a dict whose tensors all gain a leading batch dimension:
      * ``'text'`` — stacked phoneme-id tensors;
      * ``'input_ids'`` / ``'labels'`` — stacked spectrograms from each item's dict;
      * ``'stop_tokens'`` — stacked stop-token targets.
    """
    items = [(spec, tokens, stop_target) for spec, tokens, stop_target in batch]

    text = torch.stack([tokens for _, tokens, _ in items])
    input_ids = torch.stack([spec['input_ids'] for spec, _, _ in items])
    labels = torch.stack([spec['labels'] for spec, _, _ in items])
    stop = torch.stack([stop_target for _, _, stop_target in items])

    return {
        'text': text,
        'input_ids': input_ids,
        'labels': labels,
        "stop_tokens": stop,
    }


In [None]:

# Enable autograd anomaly detection globally (a debugging mode).
# NOTE(review): per the PyTorch docs this slows down training; consider disabling it once training is stable.
torch.autograd.set_detect_anomaly(True)  # Add at the start of training

In [24]:


shuffle = True

train_dataset = TTSDataset(train_outputs, train_texts, train_duration)
val_dataset = TTSDataset(val_outputs, val_texts, val_duration)

# Random generator created on `device` (the default device set earlier).
generator = torch.Generator(device=device)

# Settings shared by both loaders; only shuffling differs between them.
_loader_kwargs = dict(
    batch_size=batch_size,
    generator=generator,
    drop_last=True,
    collate_fn=collate_fn,
)
train_dataloader = DataLoader(train_dataset, shuffle=shuffle, **_loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [25]:
#Position embeddings
class SrcPositionEmbeddings(nn.Module):
    """Learned absolute position embeddings for the encoder (phoneme) sequence.

    Holds a trainable ``(1, block_size, embeddings_dims)`` table initialized from
    N(0, 1). ``forward`` returns the rows for the first ``x.size(1)`` positions, so
    the result broadcasts against an input assumed to be
    ``(batch, seq_len, embeddings_dims)`` with ``seq_len <= block_size``.
    """

    def __init__(
        self,
        embeddings_dims = embeddings_dims,
        block_size = block_size
    ):
        super().__init__()
        self.position_embeddings = nn.Parameter(torch.randn(1, block_size, embeddings_dims, device=device), requires_grad=True)

    def forward(self, x):
        # Slice to the input length rather than returning the whole table, so shorter
        # sequences broadcast correctly (identical result when seq_len == block_size).
        return self.position_embeddings[:, :x.size(1)]

In [26]:
#Position embeddings
class TgTPositionEmbeddings(nn.Module):
    """Learned absolute position embeddings for the decoder sequence.

    Holds a trainable ``(1, N_MELS, embeddings_dims)`` table initialized from
    N(0, 1). The decoder sequence axis has length N_MELS because the decoder
    pre-net projects each mel channel's time series. ``forward`` returns the rows
    for the first ``x.size(1)`` positions.
    """

    def __init__(
        self,
        embeddings_dims = embeddings_dims,
        block_size = block_size  # unused; kept for interface compatibility
    ):
        super().__init__()
        self.position_embeddings = nn.Parameter(torch.randn(1, N_MELS, embeddings_dims, device=device), requires_grad=True)

    def forward(self, x):
        # Slice to the input length so shorter sequences broadcast correctly
        # (identical result when x.size(1) == N_MELS).
        return self.position_embeddings[:, :x.size(1)]

In [27]:
# pos = PositionEmbeddings()
# x = torch.randn(batch_size, block_size, embeddings_dims)
# pos(x)

In [28]:
class PrenetEncoder(nn.Module):
  """Convolutional encoder pre-net: 3 x (Conv1d -> BatchNorm1d -> ReLU), then a Linear projection.

  Input is channels-first ``(batch, phenome_embeddings_dims, seq_len)``. Output is
  channels-last ``(batch, seq_len, prenet_encoder_embeddings_dims)``.
  """
  def __init__(self):
    super().__init__()
    self.device = device
    self.embeds_dims = phenome_embeddings_dims
    self.out = prenet_encoder_embeddings_dims
    # "Same" padding: keeps seq_len unchanged for odd kernel sizes (2 for kernel_size=5).
    padding = kernel_size // 2
    self.conv1d_layer1 = nn.Conv1d(in_channels=self.embeds_dims, out_channels=self.out, kernel_size=kernel_size, device=self.device, padding=padding)
    self.conv1d_layer2 = nn.Conv1d(in_channels=self.out, out_channels=self.out, kernel_size=kernel_size, device=self.device, padding=padding)
    self.conv1d_layer3 = nn.Conv1d(in_channels=self.out, out_channels=self.out, kernel_size=kernel_size, device=self.device, padding=padding)
    # One BatchNorm per conv layer. Each layer's activations have their own
    # statistics, and a single shared BatchNorm would mix their running mean/var
    # and force one affine transform on all three. `norm` is kept for layer 1.
    self.norm = torch.nn.BatchNorm1d(self.out, device=self.device)
    self.norm2 = torch.nn.BatchNorm1d(self.out, device=self.device)
    self.norm3 = torch.nn.BatchNorm1d(self.out, device=self.device)
    self.proj = nn.Linear(self.out, self.out, device=self.device)

  def forward(self, x):
    x = torch.nn.functional.relu(self.norm(self.conv1d_layer1(x)))
    x = torch.nn.functional.relu(self.norm2(self.conv1d_layer2(x)))
    x = torch.nn.functional.relu(self.norm3(self.conv1d_layer3(x)))
    x = x.permute(0, 2, 1)  # (batch, channels, seq) -> (batch, seq, channels)
    return self.proj(x)


In [29]:
class PrenetDecoder(nn.Module):
  """Decoder pre-net: a 3-layer MLP applied along the last axis (length ``max_t``).

  Maps ``max_t -> embeddings_dims_decoder -> phenome_embeddings_dims -> phenome_embeddings_dims``,
  with ReLU after the first two layers and no activation after the last.
  """

  def __init__(self):
    super().__init__()
    self.decoder_embeds_dims = embeddings_dims_decoder
    self.device = device
    self.out = phenome_embeddings_dims
    self.linear_layer1 = nn.Linear(in_features=max_t, out_features=self.decoder_embeds_dims, device=self.device)
    self.linear_layer2 = nn.Linear(in_features=self.decoder_embeds_dims, out_features=self.out, device=self.device)
    self.linear_layer3 = nn.Linear(self.out, self.out, device=self.device)

  def forward(self, x):
    relu = torch.nn.functional.relu
    hidden = relu(self.linear_layer1(x))
    hidden = relu(self.linear_layer2(hidden))
    return self.linear_layer3(hidden)


In [30]:

#Layer Normalization

class LayerNormalization(nn.Module):
    """Thin wrapper around ``nn.LayerNorm`` over a last dimension of size ``embeddings_dims``."""

    def __init__(
        self,
        embeddings_dims = embeddings_dims
    ):
        super().__init__()
        self.norm = nn.LayerNorm(normalized_shape=embeddings_dims)

    def forward(self, x):
        normalized = self.norm(x)
        return normalized

In [31]:

#FeedForward Neural Network

class MLPBlock(nn.Module):
    """Position-wise feed-forward network: Linear(d, 4d) -> GELU -> Linear(4d, d) -> Dropout.

    ``d`` is ``embeddings_size``.
    """
    def __init__(
        self,
        dropout = dropout,
        embeddings_size = embeddings_dims,
    ):
        super().__init__()
        # Derive the hidden width from `embeddings_size` rather than the global
        # `embeddings_dims`, so a block built with a non-default size is consistent.
        hidden_size = 4 * embeddings_size
        self.mlp = nn.Sequential(
            nn.Linear(device=device, in_features=embeddings_size, out_features=hidden_size),
            nn.GELU(),
            nn.Linear(device=device, in_features=hidden_size, out_features=embeddings_size),
            nn.Dropout(p = dropout)
        )

    def forward(self, x):
        return self.mlp(x)

In [32]:


class MaskedAttentionHead(nn.Module):
    """One head of causal (masked) self-attention.

    Projects ``x`` of shape (batch, T, embeddings_dims) to bias-free queries, keys
    and values of size ``embeddings_dims // no_of_heads``. Each position attends
    only to itself and earlier positions; dropout is applied to the attention weights.
    """
    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.keys = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.values = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.dropout = nn.Dropout(p = attn_dropout)

    def forward(self, x):
        _, seq_len, _ = x.shape
        k = self.keys(x)
        q = self.query(x)
        v = self.values(x)
        scores = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)
        # Strictly-upper-triangular entries are "future" positions and get -inf.
        future = torch.ones(seq_len, seq_len, device=device).triu(diagonal=1).bool()
        scores = scores.masked_fill(future, float('-inf'))
        attention = self.dropout(nn.functional.softmax(scores, dim=-1))  # normalize over key positions
        return attention @ v

In [33]:



class MaskedMHA(nn.Module):
    """Multi-head causal self-attention.

    Runs ``no_of_heads`` MaskedAttentionHeads, concatenates their outputs and
    applies a bias-free output projection followed by dropout.
    """
    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        self.heads = nn.ModuleList(
            MaskedAttentionHead(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads)
            for _ in range(no_of_heads)
        )
        self.dropout = nn.Dropout(p = attn_dropout)
        # Concatenated heads have no_of_heads * (embeddings_dims // no_of_heads) features.
        self.linear = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False)

    def forward(self, x):
        head_outputs = [head(x) for head in self.heads]
        projected = self.linear(torch.cat(head_outputs, dim=-1))
        return self.dropout(projected)

In [34]:

#Single Attention Head

class CrossAttentionHead(nn.Module):
    """One head of encoder-decoder (cross) attention.

    Queries come from the decoder stream and keys/values from the encoder output.
    Every decoder position may attend to every encoder position.
    """
    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.keys = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.values = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.dropout = nn.Dropout(p = attn_dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` (batch, Tq, E) over ``key``/``value`` (batch, Tk, E).

        ``mask`` is accepted for interface compatibility and is currently unused.
        NOTE(review): the decoder passes its ``tgt_mask`` here, which is not a mask
        over encoder positions.
        """
        q = self.query(query)
        k = self.keys(key)
        v = self.values(value)
        # No causal mask here: causality only constrains the decoder's own
        # self-attention. A (Tq x Tq) lower-triangular mask would restrict decoder
        # step i to encoder positions <= i, and it cannot broadcast when Tq != Tk.
        weights = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)
        weights = nn.functional.softmax(weights, dim=-1)  # normalize over encoder positions
        weights = self.dropout(weights)
        return weights @ v

In [35]:
#Single Attention Head

class FullAttentionHead(nn.Module):
    """One head of unmasked (bidirectional) self-attention, with an optional key mask.

    Dropout is applied to the head's output, not to the attention weights.
    """
    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.keys = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.values = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.dropout = nn.Dropout(p = attn_dropout)

    def forward(self, x, mask=None):
        """Self-attention over ``x`` (batch, T, E).

        Args:
            mask: optional tensor, presumably (batch, T). After ``unsqueeze(1)`` it
                broadcasts over queries, and key positions where ``mask == 0`` get
                -inf scores.
        """
        k = self.keys(x)
        q = self.query(x)
        v = self.values(x)
        weights = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)
        if mask is not None:
            weights = weights.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))
        weights_normalized = nn.functional.softmax(weights, dim=-1)  # normalize over key positions
        out = weights_normalized @ v
        return self.dropout(out)

In [36]:

class FullMHA(nn.Module):
    """Multi-head unmasked self-attention.

    Concatenates ``no_of_heads`` FullAttentionHeads (each given the same optional
    ``mask``), then applies a bias-free output projection and dropout.
    """
    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        self.heads = nn.ModuleList(
            FullAttentionHead(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads)
            for _ in range(no_of_heads)
        )
        self.dropout = nn.Dropout(p = attn_dropout)
        self.linear = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False)

    def forward(self, x, mask=None):
        head_outputs = [head(x, mask) for head in self.heads]
        projected = self.linear(torch.cat(head_outputs, dim=-1))
        return self.dropout(projected)

In [37]:


class CrossMHA(nn.Module):
    """Multi-head encoder-decoder attention built from CrossAttentionHeads.

    Note the argument order of ``forward``: ``(value, key, x, mask)``, where ``x``
    is the decoder stream used as the query.
    """
    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        self.heads = nn.ModuleList(
            CrossAttentionHead(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads)
            for _ in range(no_of_heads)
        )
        self.dropout = nn.Dropout(p = attn_dropout)
        self.linear = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False)

    def forward(self, value, key, x, mask=None):
        # Each head takes (query, key, value, mask).
        per_head = [head(x, key, value, mask) for head in self.heads]
        projected = self.linear(torch.cat(per_head, dim=-1))
        return self.dropout(projected)

In [38]:
# Decoder Block

class TransformerDecoderBlock(nn.Module):
    """Post-norm Transformer decoder block.

    Three residual sub-layers, each followed by LayerNorm: causal self-attention,
    cross-attention over the encoder output (``key``/``value``), and the MLP.
    """
    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        dropout = dropout,
    ):
        super().__init__()
        self.cross = CrossMHA(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads)
        self.masked = MaskedMHA(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads)
        self.layer_norm1 = LayerNormalization(embeddings_dims)
        self.layer_norm2 = LayerNormalization(embeddings_dims)
        self.layer_norm4 = LayerNormalization(embeddings_dims)
        self.mlp_block = MLPBlock(dropout=dropout, embeddings_size=embeddings_dims)

    def forward(self, key, value, x, mask=None):
        self_attended = self.layer_norm1(x + self.masked(x))
        cross_attended = self.layer_norm2(self_attended + self.cross(value, key, self_attended, mask))
        return self.layer_norm4(cross_attended + self.mlp_block(cross_attended))

In [39]:
class PostNet(nn.Module):
  """Post-net: 5 x (Conv1d -> BatchNorm1d), with tanh after the first four.

  Takes channels-last ``(batch, T, embeddings_dims)`` input and returns the same
  shape; it transposes to channels-first internally for Conv1d.
  """
  def __init__(self):
    super().__init__()
    self.out = embeddings_dims
    self.device = device
    padding = kernel_size // 2  # keeps T unchanged for odd kernel sizes (2 for kernel_size=5)
    self.conv_layer1 = nn.Conv1d(self.out, self.out, kernel_size=kernel_size, device=self.device, padding=padding)
    self.conv_layer2 = nn.Conv1d(self.out, self.out, kernel_size=kernel_size, device=self.device, padding=padding)
    self.conv_layer3 = nn.Conv1d(self.out, self.out, kernel_size=kernel_size, device=self.device, padding=padding)
    self.conv_layer4 = nn.Conv1d(self.out, self.out, kernel_size=kernel_size, device=self.device, padding=padding)
    self.conv_layer5 = nn.Conv1d(self.out, self.out, kernel_size=kernel_size, device=self.device, padding=padding)
    # One BatchNorm per conv layer rather than one shared across all five
    # (a shared one mixes running statistics of different layers). Layer 1 keeps `norm`.
    self.norm = torch.nn.BatchNorm1d(self.out, device=self.device)
    self.norm2 = torch.nn.BatchNorm1d(self.out, device=self.device)
    self.norm3 = torch.nn.BatchNorm1d(self.out, device=self.device)
    self.norm4 = torch.nn.BatchNorm1d(self.out, device=self.device)
    self.norm5 = torch.nn.BatchNorm1d(self.out, device=self.device)

  def forward(self, x):
    x = x.transpose(1, 2).contiguous()  # (batch, T, C) -> (batch, C, T)
    convs = (self.conv_layer1, self.conv_layer2, self.conv_layer3, self.conv_layer4, self.conv_layer5)
    norms = (self.norm, self.norm2, self.norm3, self.norm4, self.norm5)
    last = len(convs) - 1
    for i, (conv, norm) in enumerate(zip(convs, norms)):
      x = norm(conv(x))
      if i < last:  # no activation after the final layer
        x = torch.tanh(x)
    return x.transpose(1, 2).contiguous()  # back to (batch, T, C)


In [40]:
# Decoder Block

class DecoderModel(nn.Module):
    """Stack of TransformerDecoderBlocks over the pre-net'd decoder sequence.

    Adds learned positional embeddings, scaled element-wise by a trainable factor,
    to ``x``. It then runs every decoder block with the encoder output as
    ``key``/``value``.
    """
    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        block_size = block_size,  # unused; kept for interface compatibility
        dropout = dropout,
        no_of_decoder_layers = no_of_decoder_layers,
    ):
        super().__init__()
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderBlock(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads, dropout=dropout)
            for _ in range(no_of_decoder_layers)
        ])
        # Re-initialize every Linear/Embedding created so far (see _init_weights).
        self.apply(self._init_weights)
        # Built with this decoder's `embeddings_dims` (not the global default) so the
        # table matches `scaled_factor` and the block width.
        self.positional_embeddings_tgt = TgTPositionEmbeddings(embeddings_dims=embeddings_dims)
        self.scaled_factor = nn.Parameter(torch.ones(1, N_MELS, embeddings_dims), requires_grad=True)

    def _init_weights(self, module):
        """N(0, 0.02) weights for Linear/Embedding modules; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, key, value, x, mask):
        x = x + self.scaled_factor * self.positional_embeddings_tgt(x)
        for decoder_layer in self.decoder_layers:
            x = decoder_layer(key, value, x, mask)
        return x

In [41]:

#Encoder

In [42]:





class TransformerEncoderBlock(nn.Module):
    """Post-norm Transformer encoder block: self-attention then MLP, each residual + LayerNorm."""
    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        dropout = dropout,
        mask=None  # unused by the constructor; the mask is passed to forward()
    ):
        super().__init__()
        self.mha = FullMHA(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads)
        self.layer_norm1 = LayerNormalization(embeddings_dims)
        self.layer_norm2 = LayerNormalization(embeddings_dims)
        self.mlp_block = MLPBlock(dropout=dropout, embeddings_size=embeddings_dims)

    def forward(self, x, mask=None):
        attended = self.layer_norm1(x + self.mha(x, mask))
        return self.layer_norm2(attended + self.mlp_block(attended))

In [43]:




class EncoderModel(nn.Module):
    """Phoneme encoder: embedding -> conv pre-net -> scaled positional embeddings -> encoder blocks.

    The number of encoder blocks is taken from ``no_of_decoder_layers``.
    """
    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        block_size = block_size,
        dropout = dropout,
        no_of_decoder_layers = no_of_decoder_layers,
    ):
        super().__init__()
        self.prenet_enc = PrenetEncoder()
        # Element-wise trainable scale applied to the positional embeddings.
        self.trainable_factor = nn.Parameter(torch.ones(1, block_size, embeddings_dims, device=device), requires_grad=True)
        # Built with this encoder's sizes (not the globals) so it matches trainable_factor.
        self.positional_embeddings_src = SrcPositionEmbeddings(embeddings_dims=embeddings_dims, block_size=block_size)
        self.src_text_embeds = nn.Embedding(num_embeddings=src_vocab_size, embedding_dim=embeddings_dims, device=device)
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderBlock(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads, dropout=dropout)
            for _ in range(no_of_decoder_layers)
        ])
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """N(0, 0.02) weights for Linear/Embedding modules; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, mask):
        x = self.src_text_embeds(x)          # phoneme ids (presumably (batch, block_size)) -> (..., E)
        x = x.transpose(1, 2).contiguous()   # channels-first for the Conv1d pre-net
        x = self.prenet_enc(x)               # back to channels-last
        x = x + self.trainable_factor * self.positional_embeddings_src(x)
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x, mask)
        return x

In [44]:



class TTS(nn.Module):
    """Text-to-spectrogram model built from an encoder, a decoder pre-net, a
    decoder, a linear output projection, a post-net and a stop-token head.

    ``forward`` returns ``(out1, out2, stop)``:
      * ``out1`` - linear projection of the decoder output (before the post-net),
      * ``out2`` - ``postnet(out1) + out1`` (residual post-net refinement),
      * ``stop`` - raw stop-token logits; no sigmoid is applied here, the
        training loop uses ``binary_cross_entropy_with_logits`` on them.
    """

    def __init__(self):
        super().__init__()

        # Construction order is kept stable: it fixes parameter registration
        # order and the RNG draws made during initialisation.
        self.encoder = EncoderModel()
        self.decoder = DecoderModel()
        self.postnet = PostNet()
        # Projects decoder hidden states to output frames.
        # NOTE(review): out_features is embeddings_dims, not N_MELS; confirm this
        # matches the width of the 'labels' targets and of PostNet's input.
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False)
        # Stop-token head producing logits.
        # NOTE(review): emits embeddings_dims logits per position rather than one;
        # confirm against the shape of the batches' 'stop_tokens'.
        self.stop_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False)
        self.prenet_dec = PrenetDecoder()

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Run encoder, decoder, projection, post-net and stop head.

        Args:
            src: encoder input (the batch's 'text' ids in the training loop).
            tgt: decoder input, passed through ``prenet_dec`` first (the batch's
                'input_ids' in the training loop).
            src_mask: optional mask for the encoder only.
            tgt_mask: optional mask forwarded to the decoder.

        Returns:
            ``(out1, out2, stop)`` as described in the class docstring.
        """
        memory = self.encoder(src, src_mask)
        dec_in = self.prenet_dec(tgt)
        # Encoder output is passed twice - presumably as cross-attention keys
        # and values; TODO confirm against DecoderModel's signature.
        hidden = self.decoder(memory, memory, dec_in, tgt_mask)

        out1 = self.linear_layer(hidden)
        # Out-of-place residual add. An in-place `+=` would modify the post-net's
        # output tensor, which breaks backward whenever the post-net's last op
        # saved its output for gradient computation (e.g. tanh/sigmoid).
        out2 = self.postnet(out1) + out1
        stop = self.stop_layer(hidden)
        return out1, out2, stop

In [45]:
# Build the TTS model and move its parameters and buffers to the configured device.
model = TTS().to(device)


In [46]:

# print(text.shape)

In [None]:



!pip install torchinfo
from torchinfo import summary

data = next(iter(train_dataloader))
# print(data)
# tgt_mask = torch.randint(1, tgt_vocab_size, (batch_size, block_size)).to(device)  #
spec1 = data['input_ids'].to(device)
text = data['text'].to(device)

summary(model=model,
        input_data=(text, spec1),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

In [48]:

# Optimizer setup. There is no LR-scheduler object: the training loop writes
# get_lr(step) into every param group's 'lr' on each step.
# The configured `weight_decay_optim` hyperparameter is passed explicitly so that
# changing it actually takes effect.
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay_optim)

# Not used by the training loop below, which calls nn.functional.mse_loss directly.
loss_fn = nn.MSELoss()

In [49]:



torch.set_float32_matmul_precision('high')

scaler = torch.amp.GradScaler(enabled=True)

In [50]:


def _save_snapshot(model, optimizer, scheduler, epoch, step):
    """Save training state to ``snapshot_{step}.pt`` in the working directory.

    Args:
        model: module whose ``state_dict()`` is stored under "MODEL_STATE".
        optimizer: optimizer whose ``state_dict()`` is stored under "OPTIMIZER_STATE".
        scheduler: optional LR scheduler; when not None its ``state_dict()`` is
            stored under "SCHEDULER_STATE", otherwise that key is omitted.
        epoch: stored as "EPOCHS_RUN" (may be None).
        step: stored as "STEP_RUN" and used in the file name.
    """
    snapshot = {
        "MODEL_STATE": model.state_dict(),
        "OPTIMIZER_STATE": optimizer.state_dict(),
        "EPOCHS_RUN": epoch,
        "STEP_RUN": step,
    }
    if scheduler is not None:
        snapshot["SCHEDULER_STATE"] = scheduler.state_dict()
    torch.save(snapshot, f"snapshot_{step}.pt")
    print(f"Epoch: {epoch} | Step: {step} | Snapshot saved.")

In [51]:
# !pip install torchtriton

In [52]:
# --- Training-loop length and batching ----------------------------------------
total_iters = 20_000                 # optimizer steps run by the training loop
gradient_accumulation_steps = 32     # micro-batches accumulated per optimizer step
micro_batch_size = batch_size        # not referenced by the training loop below
total_batch_size = 524_288           # not referenced by the training loop below

# --- Learning-rate schedule (read by get_lr) -----------------------------------
warmup_iters = 2_048                 # length of the linear warmup
lr_decay_iters = 20_000              # cosine decay ends here; min_lr afterwards
min_lr = 0.1 * max_lr                # floor of the cosine schedule

# --- Evaluation and checkpointing ----------------------------------------------
eval_iters = 50                      # validation INTERVAL in steps (not a batch count)
eval_check = 100                     # number of val batches per validation run
save_chechpoint_iter = 50            # checkpoint interval in steps (name kept: referenced elsewhere)

In [None]:

# Number of CUDA devices visible to this process.
world_size = torch.cuda.device_count()
# Switch the model to eval mode.
model.eval()
@torch.inference_mode()
def estimate_loss(val_loader, val_iterator, device):
    out = {}
    # train_loader = prepare_dataset('train', ModelArgs.batch_size)

    # val_loader_iterator = iter(val_loader)
    loader = None
    epoch_loss = None
    epoch_losses = []
    # print("Starting the eval...")
    for split in ['val']:
        print(f"Starting with {split} evaluation...")
        # losses = torch.zeros(ModelArgs.val_epochs)
        # if(split == 'train'):
        #         loader = train_loader
        # if(split == 'val'):
        #         loader = val_loader
        for step in range(eval_check):
            try:
                data = next(val_iterator)
            except StopIteration:
                val_loader_iterator = iter(val_loader)
                data = next(val_loader_iterator)

            # tgt_mask = torch.randint(1, tgt_vocab_size, (batch_size, block_size)).to(device)  #
            total_loss = 0
            # loader.sampler.set_epoch(step)
            total_batches = 0
            # batch = next(val_loader_iterator)
            # for batch in loader:  # Loop through DataLoader batches
            # idx = batch['input_ids']
            # targets = batch['labels']

            data['text'] = data['text'].to(device)
            data['input_ids'] = data['input_ids'].to(device)
            data['labels'] = data['labels'].to(device)
            idx = data['text']
            spec = data['input_ids']
            y = data['labels']
            with torch.autocast(device_type=device, dtype=torch.float16):

                # pre, post, stop_token = model(idx, spec)
                pre, post, _stop = model(idx, spec)
                # batch_size, block_size, embeddings_dims = stop_token.shape
                # print("y: ", y.shape)
                # print("Pre: ", pre.shape)
                # print("post: ", post.shape)
                # print("stop: ", _stop.shape)

                # print(logits.shape)
                # print(targets)
                # stop_token = stop_token.view(batch_size*block_size, embeddings_dims)
                # # print("OK")
                # targets = idx.view(batch_size * block_size)

                # # print("OK2")

                pre_loss = nn.functional.mse_loss(pre, y)
                post_loss = nn.functional.mse_loss(post, y)

                # stop_token = stop_token.view(batch_size*block_size, embeddings_dims)
                # print("OK")
                # targets = idx.view(batch_size * block_size)

                # print("OK2")

                # pre_loss = nn.functional.mse_loss(pre, y)

                stop_loss = nn.functional.binary_cross_entropy_with_logits( _stop, stop_token, pos_weight = torch.tensor([8.0]))
                # print(pre_loss)
                # print(post_loss)
                # print(stop_loss)
                loss = pre_loss + post_loss + stop_loss
                total_loss += loss.item()
                total_batches += 1

        # Compute mean loss for this epoch
        epoch_loss = total_loss / total_batches if total_batches > 0 else 0.0
        epoch_losses.append(epoch_loss)

            # print(f"Epoch {epoch + 1}/{ModelArgs.val_epochs}: Loss = {epoch_loss:.4f}")

        # Compute mean loss across all evaluation epochs
        out[split] = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0.0
        epoch_loss = None
        epoch_losses = []

    model.train()
    return out

# Training-state setup: put the model in train mode, create fresh iterators over
# the train/val loaders, and initialise counters read by the training loop.
model.train()
count = 0

print("Loaders ready both")

train_data_iterator = iter(train_dataloader)
val_data_iterator = iter(val_dataloader)
token_count = 0
# `device` is the string 'cuda:0', so the former `if device == 0:` guard never
# matched and this stayed 0. Nothing here launches multiple processes, so just
# record the number of batches per epoch (the loop also calls len() on it).
train_loader_length = len(train_dataloader)

In [54]:

def find_unused_parameters(model):
    """List the names of ``model``'s parameters whose ``.grad`` is None.

    Meant to be called right after ``backward()``: a parameter still lacking a
    gradient did not contribute to the loss (or its grad was cleared with
    ``zero_grad(set_to_none=True)`` and not repopulated).
    """
    return [name for name, param in model.named_parameters() if param.grad is None]

In [55]:


import math
def get_lr(it):
    """Learning rate for iteration ``it``: linear warmup, then cosine decay.

    Uses the module-level ``warmup_iters``, ``lr_decay_iters``, ``max_lr`` and
    ``min_lr``:

    * ``it < warmup_iters``: ``max_lr * (it + 1) / (warmup_iters + 1)``.
    * ``it > lr_decay_iters``: ``min_lr``.
    * otherwise: cosine interpolation from ``max_lr`` down to ``min_lr``.
    """
    if it < warmup_iters:
        lr = max_lr * (it + 1) / (warmup_iters + 1)
    elif it > lr_decay_iters:
        lr = min_lr
    else:
        progress = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= progress <= 1
        # Weight goes 1 -> 0 as progress goes 0 -> 1.
        cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))
        span = max_lr - min_lr
        lr = min_lr + cosine_weight * span
    return lr

In [None]:

# Debug aid: fetch one training batch and print its stop-token targets.
data = next(iter(train_dataloader))
for key, value in data.items():
    if key != 'stop_tokens':
        continue
    print(value)
    break

In [None]:
# Main training loop: gradient accumulation with float16 autocast + GradScaler,
# warmup/cosine learning rate from get_lr(), optional gradient-norm clipping,
# periodic validation via estimate_loss(), and wandb logging.
#
# NOTE(review): `tqdm` and `clip` are not defined in any cell shown here; they
# are assumed to come from elsewhere in the notebook - confirm.
model.train()
train_losses = torch.zeros(len(train_dataloader))
val_losses = torch.zeros(len(val_dataloader))
wandb.init(
    project='TTS-From-Scratch'
)

# torch.autocast expects a device *type* such as 'cuda', not an indexed device
# string like 'cuda:0'.
amp_device_type = torch.device(device).type
# Positive-class weight for the stop-token BCE. It has to live on the same
# device as the logits/targets; a CPU tensor here mixes devices inside the loss.
stop_pos_weight = torch.tensor([8.0], device=device)

step = 0
for step in tqdm(range(total_iters)):
    print("Step : ", step, "/", total_iters)
    print('Total batches: ', len(train_dataloader))
    print("Total gradient accumulation steps: ", gradient_accumulation_steps)

    # Validate every `eval_iters` steps, and always on the final step.
    if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
        losses = estimate_loss(val_dataloader, val_data_iterator, 'cuda')
        print(f"[GPU {device}] | Step: {step} / {total_iters} | Val Loss: {losses['val']:.4f}")
        avg_val_loss = torch.Tensor([losses['val']]).to(device)
        # No cross-process reduction is performed, so the local loss already is
        # the average; dividing by the visible GPU count would under-report it
        # on multi-GPU machines.
        all_gpus_avg_val_loss = avg_val_loss
        print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss.item():.4f}")
        wandb.log({
            "All_GPUs_Val_losses": all_gpus_avg_val_loss,
            "val_step_loss": losses['val'],
        })

    # NOTE(review): `device` is the string 'cuda:0', so `device == 0` is never
    # true and no snapshot is ever written. Before enabling it, note that every
    # save writes a new snapshot file each `save_chechpoint_iter` steps -
    # budget disk space (or keep only the latest) first.
    if step % save_chechpoint_iter == 0 and device == 0 and step != 0:
        print(f"Saving the model checkpoint for step: {step}")
        _save_snapshot(model, optimizer, None, None, step)

    accumulated_loss = 0.0
    optimizer.zero_grad(set_to_none=True)
    for micro_step in range(gradient_accumulation_steps):
        try:
            data = next(train_data_iterator)
        except StopIteration:
            # End of an epoch: start a new pass over the training data.
            train_data_iterator = iter(train_dataloader)
            data = next(train_data_iterator)

        data['text'] = data['text'].to(device)
        data['input_ids'] = data['input_ids'].to(device)
        data['labels'] = data['labels'].to(device)
        data['stop_token'] = data['stop_tokens'].to(device)
        idx = data['text']
        spec = data['input_ids']
        y = data['labels']
        stop_token = data['stop_token']

        with torch.autocast(device_type=amp_device_type, dtype=torch.float16):
            pre, post, _stop = model(idx, spec)
            pre_loss = nn.functional.mse_loss(pre, y)
            post_loss = nn.functional.mse_loss(post, y)
            stop_loss = nn.functional.binary_cross_entropy_with_logits(
                _stop, stop_token, pos_weight=stop_pos_weight
            )
            loss = pre_loss + post_loss + stop_loss
            # Gradients are summed over micro-batches; dividing by their count
            # makes that sum the gradient of the *mean* loss over the whole
            # effective batch.
            loss = loss / gradient_accumulation_steps
            accumulated_loss += loss.detach()

        # Flag DistributedDataParallel reads to skip the gradient all-reduce on
        # all but the last micro-batch; on a plain nn.Module it is just stored.
        model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        scaler.scale(loss).backward()

        unused_params = find_unused_parameters(model)
        if unused_params:
            print(f"Unused parameters: {unused_params}")

        if micro_step % 10 == 0:
            print("Micro Batch : ", micro_step)
            print("Step : ", step, "/", total_iters)
            print('Total batches: ', len(train_dataloader))
            print("Total gradient accumulation steps: ", gradient_accumulation_steps)
            print("Total tokens processed: ", token_count)

    lr = get_lr(step)
    for params in optimizer.param_groups:
        params['lr'] = lr

    if clip != 0.0:
        # Gradients are still multiplied by the loss scale; unscale first so the
        # clipping threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
        # clip_grad_norm_ returns the total norm measured *before* clipping and
        # ignores parameters whose .grad is None.
        total_norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
        grads = [p.grad.detach() for p in model.parameters() if p.grad is not None]
        total_norm_after = (
            torch.norm(torch.stack([torch.norm(g, 2) for g in grads]), 2)
            if grads else torch.zeros(())
        )
        # The old `device == 0` guard never matched the string 'cuda:0', so these
        # lines were never printed.
        if step != 0:
            print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
            print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")

    scaler.step(optimizer)
    scaler.update()

    wandb.log({
        "Learning Rate": lr,
        "All_GPUs_Train_losses": accumulated_loss.item(),
        "Step": step,
    })