In [1]:
!pip install -r ../requirements.txt

Collecting esm==3.1.3 (from -r ../requirements.txt (line 1))
  Using cached esm-3.1.3-py3-none-any.whl.metadata (15 kB)
Collecting transformers==4.46.3 (from -r ../requirements.txt (line 2))
  Using cached transformers-4.46.3-py3-none-any.whl.metadata (44 kB)
Collecting sentencepiece==0.2.0 (from -r ../requirements.txt (line 3))
  Using cached sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Collecting tables==3.10.1 (from -r ../requirements.txt (line 4))
  Using cached tables-3.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.2 kB)
Collecting pingouin==0.5.5 (from -r ../requirements.txt (line 5))
  Using cached pingouin-0.5.5-py3-none-any.whl.metadata (19 kB)
Collecting POT==0.9.5 (from -r ../requirements.txt (line 6))
  Using cached POT-0.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (34 kB)
Collecting colorcet==3.1.0 (from -r ../requirements.txt (line 7))
  Using cached colorcet-3.1.

In [2]:
import os, torch, time, gc, itertools
import pandas as pd
import numpy as np
import torch.optim as optim
from torch.nn import BCEWithLogitsLoss
from data_module import parallel_load, layernorm
from model_module import Classifier, set_dataloader, train_model

In [None]:
# --- experiment settings ---
# embedding model versions and the embedding types available for each of them
embed_ver = ['esm3', 'esm2', 'bert', 't5']
embed_types = ['allmean', 'aamean', 'bos', 'eos', 'first', 'center', 'last']
set_ver = 'tr'      # dataset split tag used in the embedding file names
batch_size = 256

# --- locations (relative to the notebook directory) ---
data_path = '../data'
model_path = '../model'
result_path = '../result/history'

# make sure the output directories exist before anything is written to them
for _out_dir in (model_path, result_path):
    os.makedirs(_out_dir, exist_ok=True)

In [4]:
# run on the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
def _integrate_embeddings(dfs, df_keys, integ_ver):
    """Build the classifier input matrix from the selected embedding tables.

    Args:
        dfs: mapping from embedding-type name to its table (the object
            returned by ``parallel_load``); each table must offer ``to_numpy()``.
        df_keys: embedding-type names to combine, in order.
        integ_ver: ``'sum'`` adds the ``layernorm``-ed tables element-wise,
            ``'cat'`` concatenates the raw tables along axis 1, and any other
            value (``'indiv'``) returns the first table unchanged.

    Returns:
        The combined NumPy array.
    """
    if integ_ver == 'sum':
        # NOTE(review): the in-place += assumes layernorm returns a fresh array
        # rather than a view of the table's own data -- confirm in data_module.
        X = layernorm(dfs[df_keys[0]].to_numpy())
        for df_key in df_keys[1:]:
            X += layernorm(dfs[df_key].to_numpy())
        return X
    if integ_ver == 'cat':
        # a single concatenate avoids recopying the growing array per key
        return np.concatenate([dfs[df_key].to_numpy() for df_key in df_keys], axis=1)
    return dfs[df_keys[0]].to_numpy()


# Train one classifier per (embedding version, embedding-type combination,
# integration mode); save each training history as CSV under result_path.
time_total = time.time()
for emb_ver in embed_ver:
    print(f">>> {emb_ver.upper()} <<<")
    time_emb = time.time()

    # load training datasets of the embedding version
    print("> Dataset loading...")
    time_start = time.time()
    df_info = pd.read_hdf(os.path.join(data_path, f'emb-{emb_ver}_{set_ver}.h5'), key='info')
    dfs = parallel_load(embed_types, data_path, 'emb', emb_ver, set_ver)
    gc.collect()
    print(f"--- Dataset loading complete: {time.time() - time_start:.1f} sec ---")

    for r in range(1, len(embed_types[:-2]) + 1):
        # single embeddings use every type; combinations leave out the last two
        emb_types = embed_types[:-2] if r > 1 else embed_types
        # one embedding is used as is; several are merged by 'sum' and by 'cat'
        integ_versions = ['indiv'] if r == 1 else ['sum', 'cat']

        for df_keys in itertools.combinations(emb_types, r):
            comb_ver = '_'.join(df_keys)
            # combinations containing 'bos' are skipped for the T5 embeddings
            if 't5' in emb_ver.lower() and 'bos' in comb_ver.lower():
                continue

            for integ_ver in integ_versions:
                #### Data preprocessing ####
                # start the timer per model so the reported preprocessing time
                # never includes the previous model's training time
                time_start = time.time()
                model_ver = f'{emb_ver}-{integ_ver}-{comb_ver}'
                print(f"> '{model_ver}' data-loader setting...")

                X_all = _integrate_embeddings(dfs, df_keys, integ_ver)

                # get label data
                y_all = df_info['ess'].to_numpy('float')
                print(f"Input & output shape: {X_all.shape}, {y_all.shape}")

                # set dataloader
                train_loader, valid_loader, pos_weight = set_dataloader(
                    X_all, y_all, batch_size, device
                )
                print(f"--- Dataloader setting complete: {time.time() - time_start:.1f} sec ---")


                #### Model training ####
                print(f"> '{model_ver}' model training...")

                # generate modeling instances
                model = Classifier(input_size=X_all.shape[-1]).to(device)
                model = torch.compile(model)
                criterion = BCEWithLogitsLoss(pos_weight=pos_weight)
                optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

                # training the model
                time_start = time.time()
                history = train_model(
                    model=model,
                    train_loader=train_loader,
                    valid_loader=valid_loader,
                    criterion=criterion,
                    optimizer=optimizer,
                    model_path=model_path,
                    model_name=model_ver
                )
                print(f"--- Training complete: {time.time() - time_start:.1f} sec ---")

                # get training history; idxmax returns an index label, which is
                # what the .loc lookups below expect
                df_hist = pd.DataFrame(history)
                best_idx = df_hist['valid_mcc'].idxmax()
                print(f"[{model_ver}]")
                print(f"- Best epoch: {df_hist.loc[best_idx, 'epoch']} |",
                      f"Best loss: {df_hist.loc[best_idx, 'valid_loss']:.4f} |",
                      f"Best metric: {df_hist.loc[best_idx, 'valid_mcc']:.4f}\n")
                # save the training history
                df_hist.to_csv(
                    os.path.join(result_path, f"{model_ver}.csv"),
                    index=False
                )

                # release memory, including the loaders and the class weight,
                # so they are not kept alive while the next dataset is built
                del model, criterion, optimizer, X_all, y_all
                del train_loader, valid_loader, pos_weight, history, df_hist
                torch.cuda.empty_cache()
                gc.collect()

    print(f"=== '{emb_ver}' classifiers training complete: {time.time() - time_emb:.1f} sec ===\n")

print(f"=== All training complete: {time.time() - time_total:.1f} sec ===")