In [1]:
import torch
import torchaudio
from torch.utils.data import DataLoader
from pytorch_metric_learning import miners, losses
from dataloader import get_IEMOCAPAudio_dataloader
import numpy as np
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prepare your datasets and dataloaders
train_loader = get_IEMOCAPAudio_dataloader("D:/bert-based-selfalign\IEMOCAP_wav2vec\IEMOCAP_6_class/train.pkl", train= True, batch_size=4)
val_loader = get_IEMOCAPAudio_dataloader("D:/bert-based-selfalign\IEMOCAP_wav2vec\IEMOCAP_6_class/val.pkl", train= False, batch_size=8)

In [3]:
for i in train_loader:
    wav, label = i
    print(wav)

tensor([[ 0.0016, -0.0006, -0.0036,  ..., -0.0023, -0.0019, -0.0009],
        [-0.0034, -0.0037, -0.0033,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0007,  0.0007,  0.0006,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0033, -0.0033, -0.0028,  ...,  0.0000,  0.0000,  0.0000]])
tensor([[ 0.0015,  0.0054,  0.0058,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0309,  0.0305,  0.0306,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0063,  0.0063,  0.0063,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0004, -0.0006, -0.0008,  ...,  0.0012,  0.0012,  0.0011]])
tensor([[-0.0021, -0.0009, -0.0004,  ..., -0.0023, -0.0021, -0.0020],
        [ 0.0015,  0.0012,  0.0010,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0012,  0.0012,  0.0013,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0005,  0.0004,  0.0002,  ...,  0.0000,  0.0000,  0.0000]])
tensor([[ 6.1035e-05, -3.3569e-04, -6.7139e-04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -1.8311e-04,  3.0518e-05,  ..., 

In [4]:
# Load the Wav2Vec model
model = torchaudio.pipelines.WAV2VEC2_BASE.get_model()
model.to(device)

Wav2Vec2Model(
  (feature_extractor): FeatureExtractor(
    (conv_layers): ModuleList(
      (0): ConvLayerBlock(
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
      )
      (1-4): 4 x ConvLayerBlock(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
      )
      (5-6): 2 x ConvLayerBlock(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
      )
    )
  )
  (encoder): Encoder(
    (feature_projection): FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (pos_conv_embed): ConvolutionalPositionalEmbedding(
        (conv): ParametrizedConv1d(
          768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16
          (parametriza

In [5]:
# Miner and Loss
miner = miners.MultiSimilarityMiner()
criterion = losses.MultiSimilarityLoss(1, 60, 0.5)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

In [6]:
def train(dataloader,loss_func):
    model.train()
    losses = []
    progress_bar = tqdm(dataloader, leave=True)
    for batch in progress_bar:
        optimizer.zero_grad()
        
        waveforms, labels = batch
        waveforms = waveforms.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs = model(waveforms)[0][:,0]
        # Mining
        hard_pairs = miner(outputs, labels)
        
        # Compute loss
        loss = loss_func(outputs, labels, hard_pairs)
        
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        avg_loss = np.mean(losses) if losses else 0.0
        progress_bar.set_description(f"[Train] Avg Loss: {avg_loss:.4f}, Loss: {loss.item():.4f}")
        
    return np.mean(losses)

def validation(dataloader,loss_func):
    model.eval()
    losses = []
    progress_bar = tqdm(dataloader, total=len(dataloader), leave=True, desc="Evaluating")
    with torch.no_grad():
        for batch in progress_bar:
            waveforms, labels = batch
            waveforms = waveforms.to(device)
            labels = labels.to(device)
            
            # Forward pass
            outputs = model(waveforms)[0][:,0]

            # Mining
            hard_pairs = miner(outputs, labels)
            
            # Compute loss
            loss = loss_func(outputs, labels, hard_pairs)
            
            losses.append(loss.item())
            avg_loss = np.mean(losses) if losses else 0.0
            progress_bar.set_description(f"[Val] Avg Loss: {avg_loss:.4f}, Loss: {loss.item():.4f}")
        
    return np.mean(losses)

In [7]:
# Training loop with validation
num_epochs = 10

best_val_loss = float('inf')

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    train_loss = train(train_loader, loss_func = criterion)
    print(f'Train Loss :{train_loss}')
    
    # Validation Phase
    dev_loss = validation(val_loader, loss_func = criterion)
    print(f'Validation Loss :{dev_loss}')
    # Save the model if the validation loss has improved
    if dev_loss  < best_val_loss:
        print(f"Validation loss improved from {best_val_loss} to {dev_loss}. Saving model to model_best.pth")
        best_val_loss = dev_loss 
        torch.save(model.state_dict(), "SA W2V model_best.pth")

Epoch 1/10


  0%|          | 0/1328 [00:00<?, ?it/s]

torch.Size([4, 768])


[Train] Avg Loss: 0.4644, Loss: 0.4644:   0%|          | 1/1328 [00:04<1:33:52,  4.24s/it]

torch.Size([4, 768])


[Train] Avg Loss: 0.4473, Loss: 0.4301:   0%|          | 2/1328 [00:05<52:26,  2.37s/it]  

torch.Size([4, 768])


[Train] Avg Loss: 0.2982, Loss: 0.0000:   0%|          | 3/1328 [00:06<37:23,  1.69s/it]

torch.Size([4, 768])


[Train] Avg Loss: 0.3530, Loss: 0.5177:   0%|          | 4/1328 [00:11<1:08:56,  3.12s/it]

torch.Size([4, 768])


[Train] Avg Loss: 0.2824, Loss: 0.0000:   0%|          | 5/1328 [00:11<47:45,  2.17s/it]  

torch.Size([4, 768])


[Train] Avg Loss: 0.2354, Loss: 0.0000:   0%|          | 6/1328 [00:15<58:57,  2.68s/it]

torch.Size([4, 768])


[Train] Avg Loss: 0.2331, Loss: 0.2193:   1%|          | 7/1328 [00:16<44:15,  2.01s/it]

torch.Size([4, 768])


[Train] Avg Loss: 0.2039, Loss: 0.0000:   1%|          | 8/1328 [00:16<34:41,  1.58s/it]

torch.Size([4, 768])


[Train] Avg Loss: 0.2845, Loss: 0.9290:   1%|          | 9/1328 [00:17<26:48,  1.22s/it]

torch.Size([4, 768])


[Train] Avg Loss: 0.2845, Loss: 0.9290:   1%|          | 9/1328 [00:22<55:04,  2.51s/it]


KeyboardInterrupt: 