In [None]:
import os
import sys

# Make the project's source tree importable from this notebook.
# The path is derived from the current working directory by replacing the
# substring "notebooks" with "src"; if the working directory does not contain
# "notebooks", lib_path is simply the working directory itself.
lib_path = os.path.abspath("").replace("notebooks", "src")
# Only append when missing, so re-running this cell does not accumulate
# duplicate sys.path entries in a long-lived kernel.
if lib_path not in sys.path:
    sys.path.append(lib_path)

In [2]:
import datetime
import logging
import random

import numpy as np
import torch
from torch import nn, optim
from transformers import BertTokenizer, RobertaTokenizer,AutoTokenizer

from configs.base import Config
from data.dataloader import build_train_test_dataset
from models import losses, networks
from trainer import Trainer
from utils.configs import get_options
from utils.torch.callbacks import CheckpointsCallback

# Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) random
# number generators with the same fixed value.
SEED = 0
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Display the torch device this notebook will run on (sanity check).
device

device(type='cuda')

In [4]:
# Read the experiment options from the bert_vggish config under lib_path.
opt = get_options(f"{lib_path}/configs/bert_vggish.py")
logging.info("Initializing model...")

# Model
# Resolve the network class by name *before* constructing it, so that only an
# unknown ``opt.model_type`` is reported as "not implemented". Previously the
# whole constructor call sat inside ``except AttributeError``, which also
# swallowed a missing option on ``opt`` or an AttributeError raised inside the
# model's own code and mislabeled it as an unimplemented model.
model_cls = getattr(networks, opt.model_type, None)
if model_cls is None:
    raise NotImplementedError("Model {} is not implemented".format(opt.model_type))

network = model_cls(
    num_classes=opt.num_classes,
    num_attention_head=opt.num_attention_head,
    dropout=opt.dropout,
    text_encoder_type=opt.text_encoder_type,
    text_encoder_dim=opt.text_encoder_dim,
    text_unfreeze=opt.text_unfreeze,
    audio_encoder_type=opt.audio_encoder_type,
    audio_encoder_dim=opt.audio_encoder_dim,
    audio_unfreeze=opt.audio_unfreeze,
    audio_norm_type=opt.audio_norm_type,
    fusion_head_output_type=opt.fusion_head_output_type,
)
# Move the model to the selected device.
network.to(device)

2023-10-13 09:01:36,810 - root - INFO - Initializing model...
Some weights of RobertaModel were not initialized from the model checkpoint at SamLowe/roberta-base-go_emotions and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Display the model's repr to inspect its layer structure.
network

MMSERA(
  (text_encoder): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNo

In [6]:

logging.info("Initializing checkpoint directory and dataset...")

# Choose the tokenizer class and pretrained weights that match the configured
# text encoder; any other encoder type is rejected.
_tokenizer_sources = {
    "bert": (BertTokenizer, "bert-base-uncased"),
    "roberta": (AutoTokenizer, "SamLowe/roberta-base-go_emotions"),
}
if opt.text_encoder_type not in _tokenizer_sources:
    raise NotImplementedError("Tokenizer {} is not implemented".format(opt.text_encoder_type))
_tokenizer_cls, _pretrained_name = _tokenizer_sources[opt.text_encoder_type]
tokenizer = _tokenizer_cls.from_pretrained(_pretrained_name)

# Prepare the checkpoint directory: <abs checkpoint_dir>/<name>/<timestamp>.
# NOTE: opt.checkpoint_dir is overwritten with this run-specific path, so
# re-running the cell without reloading ``opt`` builds on the already-extended
# path.
_run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_dir = os.path.join(os.path.abspath(opt.checkpoint_dir), opt.name, _run_stamp)
opt.checkpoint_dir = checkpoint_dir

# Create the sub-directories for logs and saved weights.
log_dir = os.path.join(checkpoint_dir, "logs")
weight_dir = os.path.join(checkpoint_dir, "weights")
for _directory in (log_dir, weight_dir):
    os.makedirs(_directory, exist_ok=True)

# Presumably persists the options for this run (the cell output shows them
# being logged) — confirm against the Config implementation.
opt.save(opt)

2023-10-13 09:01:44,376 - root - INFO - Initializing checkpoint directory and dataset...
2023-10-13 09:01:44,992 - root - INFO - 
             audio_encoder_dim: 128                                     
            audio_encoder_type: vggish                                  
              audio_max_length: 50                                      
               audio_norm_type: layer_norm                              
                audio_unfreeze: False                                   
                    batch_size: 4                                       
                checkpoint_dir: d:\MMSERA\notebooks\checkpoints\bert_vggish_MMSERA\20231013-090144
                     data_root: D:/MELD/MELD                            
                       dropout: 0.5                                     
                      feat_dim: 2048                                    
              focal_loss_alpha: None                                    
              focal_loss_gamma: 0.5      

In [7]:
# Build the train and test datasets from the configured data root.
_dataset_options = {
    "text_max_length": opt.text_max_length,
    "audio_encoder_type": opt.audio_encoder_type,
}
_datasets = build_train_test_dataset(
    opt.data_root,
    opt.batch_size,
    tokenizer,
    opt.audio_max_length,
    **_dataset_options,
)
# The builder returns a (train, test) pair.
train_ds, test_ds = _datasets

In [8]:
logging.info("Initializing trainer...")

# Build the loss function named by ``opt.loss_type``.
if opt.loss_type == "FocalLoss":
    # FocalLoss takes its own hyper-parameters, so it is constructed explicitly.
    criterion = losses.FocalLoss(gamma=opt.focal_loss_gamma, alpha=opt.focal_loss_alpha)
else:
    # Look the class up first so that only an unknown loss name is reported as
    # "not implemented". Previously the constructor call was inside
    # ``except AttributeError``, which also masked a missing option on ``opt``
    # (feat_dim / num_classes / lambda_c) or an AttributeError raised inside
    # the loss itself.
    loss_cls = getattr(losses, opt.loss_type, None)
    if loss_cls is None:
        raise NotImplementedError("Loss {} is not implemented".format(opt.loss_type))
    criterion = loss_cls(
        feat_dim=opt.feat_dim,
        num_classes=opt.num_classes,
        lambda_c=opt.lambda_c,
    )

# Move the criterion to the same device as the model (done for both branches).
criterion.to(device)

2023-10-13 09:01:57,382 - root - INFO - Initializing trainer...


In [9]:
# Wrap the model and loss in the project's Trainer.
# NOTE(review): the trainer's log_dir is the run's checkpoint root
# (opt.checkpoint_dir), not the "logs" sub-directory held in the ``log_dir``
# variable — confirm this is intended.
trainer = Trainer(network=network, criterion=criterion, log_dir=opt.checkpoint_dir)

In [10]:
logging.info("Start training...")

# Build the optimizer and its step-based learning-rate scheduler.
optimizer = optim.Adam(trainer.network.parameters(), lr=opt.learning_rate)
lr_scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=opt.learning_rate_step_size,
    gamma=opt.learning_rate_gamma,
)

# Checkpoint callback writing into the run's "weights" directory.
_checkpoint_options = {
    "checkpoint_dir": weight_dir,
    "save_freq": opt.save_freq,
    "max_to_keep": opt.max_to_keep,
    "save_best_val": opt.save_best_val,
    "save_all_states": opt.save_all_states,
}
ckpt_callback = CheckpointsCallback(**_checkpoint_options)

# Attach the optimizer and scheduler to the trainer, then optionally restore
# a previous run's state when resuming is enabled in the options.
trainer.compile(optimizer=optimizer, scheduler=lr_scheduler)
if opt.resume:
    trainer.load_all_states(opt.resume_path)

2023-10-13 09:01:57,428 - root - INFO - Start training...
                            Otherwise, the best model will not be saved.
                            The model will save the lowest validation value if the metric starts with 'loss' and the highest value otherwise.


In [11]:
# Run training with the checkpoint callback attached. test_ds is presumably
# the validation set used for the per-epoch "Performing validation..." step
# seen in the output — confirm against Trainer.fit.
trainer.fit(
    train_ds,
    opt.num_epochs,
    test_ds,
    callbacks=[ckpt_callback],
)

Epoch 0/10
2023-10-13 09:02:00,063 - Training - INFO - Epoch 0/10
loss: 2.0044 acc: 0.0000 : : 1522it [06:04,  4.17it/s]                         
Epoch 0 - loss: 1.0566
2023-10-13 09:08:04,752 - Training - INFO - Epoch 0 - loss: 1.0566
Epoch 0 - acc: 0.5828
2023-10-13 09:08:04,752 - Training - INFO - Epoch 0 - acc: 0.5828
Performing validation...
2023-10-13 09:08:04,771 - Training - INFO - Performing validation...
100%|##########| 679/679 [00:53<00:00, 12.64it/s]
Validation: loss: 1.0861 acc: 0.5803 
2023-10-13 09:08:58,699 - Training - INFO - Validation: loss: 1.0861 acc: 0.5803 
Model loss improve from inf to 1.0861458885037671, Saving model...
2023-10-13 09:08:58,709 - Training - INFO - Model loss improve from inf to 1.0861458885037671, Saving model...
Model acc improve from inf to 0.5802650957290133, Saving model...
2023-10-13 09:09:28,220 - Training - INFO - Model acc improve from inf to 0.5802650957290133, Saving model...
Epoch 1/10
2023-10-13 09:09:58,464 - Training - INFO - Epo