In [1]:
import os
import sys

# Locate the project's ``src`` directory by swapping the "notebooks" segment of
# the current working directory for "src".
# Only the LAST occurrence is swapped, so an ancestor directory whose name also
# contains "notebooks" is left untouched. If the working directory contains no
# "notebooks" segment at all, the working directory itself is used (as before).
_head, _sep, _tail = os.path.abspath("").rpartition("notebooks")
lib_path = _head + "src" + _tail if _sep else _tail

# Re-running this cell in the same kernel must not pile up duplicate entries.
if lib_path not in sys.path:
    sys.path.append(lib_path)

In [2]:
import datetime
import logging
import random

import numpy as np
import torch
from torch import nn, optim
from transformers import BertTokenizer, RobertaTokenizer,AutoTokenizer

from configs.base import Config
from data.dataloader import build_train_test_dataset
from models import losses, networks
from trainer import Trainer
from utils.configs import get_options
from utils.torch.callbacks import CheckpointsCallback

# Fix every random number generator used in this notebook to one seed so that
# runs are repeatable: Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices).
SEED = 0
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(SEED)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Display the selected compute device ("cuda" when available, otherwise "cpu").
device

device(type='cuda')

In [4]:
# Load the experiment configuration and instantiate the network it names.
opt = get_options(f"{lib_path}/configs/bert_vggish.py")
logging.info("Initializing model...")

# Model
# Resolve the class by name first, so that ONLY an unknown ``opt.model_type`` is
# reported as "not implemented". Construction happens outside any ``except``
# clause: an AttributeError caused by a missing config field or raised inside
# the model's own constructor now surfaces with its real traceback instead of
# being masked as a missing model.
network_cls = getattr(networks, opt.model_type, None)
if network_cls is None:
    raise NotImplementedError("Model {} is not implemented".format(opt.model_type))

network = network_cls(
    num_classes=opt.num_classes,
    num_attention_head=opt.num_attention_head,
    dropout=opt.dropout,
    text_encoder_type=opt.text_encoder_type,
    text_encoder_dim=opt.text_encoder_dim,
    text_unfreeze=opt.text_unfreeze,
    audio_encoder_type=opt.audio_encoder_type,
    audio_encoder_dim=opt.audio_encoder_dim,
    audio_unfreeze=opt.audio_unfreeze,
    audio_norm_type=opt.audio_norm_type,
    fusion_head_output_type=opt.fusion_head_output_type,
)
# Move the model's parameters to the selected device.
network.to(device)

2023-10-11 10:03:49,803 - root - INFO - Initializing model...


In [5]:
# Display the module tree of the instantiated network for inspection.
network

MMSERA(
  (text_encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [6]:

logging.info("Initializing checkpoint directory and dataset...")

# Supported text encoders, each mapped to (tokenizer class, pretrained checkpoint name).
tokenizer_sources = {
    "bert": (BertTokenizer, "bert-base-uncased"),
    "roberta": (AutoTokenizer, "SamLowe/roberta-base-go_emotions"),
}
if opt.text_encoder_type not in tokenizer_sources:
    raise NotImplementedError("Tokenizer {} is not implemented".format(opt.text_encoder_type))
tokenizer_cls, pretrained_name = tokenizer_sources[opt.text_encoder_type]
tokenizer = tokenizer_cls.from_pretrained(pretrained_name)

# Prepare the checkpoint directory: <checkpoint_dir>/<run name>/<timestamp>,
# with "logs" and "weights" sub-directories for this run.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_dir = os.path.join(os.path.abspath(opt.checkpoint_dir), opt.name, run_stamp)
# The resolved run directory is written back to the config before it is saved.
opt.checkpoint_dir = checkpoint_dir

log_dir = os.path.join(checkpoint_dir, "logs")
weight_dir = os.path.join(checkpoint_dir, "weights")
for run_subdir in (log_dir, weight_dir):
    os.makedirs(run_subdir, exist_ok=True)

opt.save(opt)

2023-10-11 10:03:54,194 - root - INFO - Initializing checkpoint directory and dataset...
2023-10-11 10:03:54,689 - root - INFO - 
             audio_encoder_dim: 128                                     
            audio_encoder_type: vggish                                  
              audio_max_length: 50                                      
               audio_norm_type: layer_norm                              
                audio_unfreeze: False                                   
                    batch_size: 4                                       
                checkpoint_dir: d:\SER_ICIIT_2024\notebooks\checkpoints\bert_vggish_MMSERA\20231011-100354
                     data_root: D:/MELD/MELD                            
                       dropout: 0.5                                     
                      feat_dim: 2048                                    
              focal_loss_alpha: None                                    
              focal_loss_gamma: 0

In [7]:
# Build dataset
# Build the train/test datasets from the configured data root.
# The keyword options are collected first so the call itself stays short.
dataset_options = {
    "text_max_length": opt.text_max_length,
    "audio_encoder_type": opt.audio_encoder_type,
}
train_ds, test_ds = build_train_test_dataset(
    opt.data_root, opt.batch_size, tokenizer, opt.audio_max_length, **dataset_options
)

In [8]:
# Build the loss (criterion) selected by ``opt.loss_type`` and move it to the device.
logging.info("Initializing trainer...")
if opt.loss_type == "FocalLoss":
    # FocalLoss takes its own hyper-parameters rather than the common ones below.
    criterion = losses.FocalLoss(gamma=opt.focal_loss_gamma, alpha=opt.focal_loss_alpha)
else:
    # Resolve the class by name first, so that ONLY an unknown ``opt.loss_type``
    # is reported as "not implemented". An AttributeError caused by a missing
    # config field or raised inside the loss constructor is no longer masked.
    loss_cls = getattr(losses, opt.loss_type, None)
    if loss_cls is None:
        raise NotImplementedError("Loss {} is not implemented".format(opt.loss_type))
    criterion = loss_cls(
        feat_dim=opt.feat_dim,
        num_classes=opt.num_classes,
        lambda_c=opt.lambda_c,
    )
criterion.to(device)

2023-10-11 10:03:56,349 - root - INFO - Initializing trainer...


In [9]:
# Wrap the network and loss in a Trainer; its logs go under this run's checkpoint directory.
trainer = Trainer(network=network, criterion=criterion, log_dir=opt.checkpoint_dir)

In [10]:
logging.info("Start training...")

# Optimizer over the trainer's network parameters, plus a step-wise LR schedule.
optimizer = optim.Adam(trainer.network.parameters(), lr=opt.learning_rate)
lr_scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=opt.learning_rate_step_size,
    gamma=opt.learning_rate_gamma,
)

# Checkpoint callback writing into this run's "weights" directory; all of its
# options are taken from the config.
checkpoint_options = {
    "checkpoint_dir": weight_dir,
    "save_freq": opt.save_freq,
    "max_to_keep": opt.max_to_keep,
    "save_best_val": opt.save_best_val,
    "save_all_states": opt.save_all_states,
}
ckpt_callback = CheckpointsCallback(**checkpoint_options)

# Attach optimizer and scheduler to the trainer, then optionally restore a previous run.
trainer.compile(optimizer=optimizer, scheduler=lr_scheduler)
if opt.resume:
    trainer.load_all_states(opt.resume_path)

2023-10-11 10:03:56,375 - root - INFO - Start training...
                            Otherwise, the best model will not be saved.
                            The model will save the lowest validation value if the metric starts with 'loss' and the highest value otherwise.


In [11]:
# Run training: train_ds and opt.num_epochs are passed first, test_ds third,
# and the checkpoint callback via ``callbacks``.
# NOTE(review): the recorded run of this cell aborted early in the first epoch with
# "AttributeError: 'list' object has no attribute 'to'" — presumably some batch
# element produced by the dataset is a Python list rather than a tensor; verify
# the dataset/collate output against what Trainer moves to the device.
trainer.fit(train_ds, opt.num_epochs, test_ds, callbacks=[ckpt_callback])

Epoch 0/250
2023-10-11 10:03:57,142 - Training - INFO - Epoch 0/250
  0%|          | 1/1521 [00:00<00:09, 166.74it/s]


AttributeError: 'list' object has no attribute 'to'