In [5]:
from base_classifier import BaseClassifier
from multilabel import MultiLabelClassifier
from multiclass import MultiClassClassifier
from tokenizer import Tokenizer
from datamodule import MedDataModule, Collator

import pathlib
import yaml
import pandas as pd
from sys import path
from pathlib import Path
from argparse import ArgumentParser, Namespace
from collections import OrderedDict

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything
from torchnlp.random import set_seed
from torchmetrics.functional import f1, precision_recall

In [1]:
def load_hparams(experiment_dir: str):
    """ Function that loads the hyper-parameters saved with an experiment.
    :param experiment_dir: Path (``str`` or ``Path``) to the experiment folder
        that contains ``hparams.yaml``.
    Return:
        - ``Namespace`` whose attributes are the top-level keys of the YAML
          file; an empty ``Namespace`` when the file contains no document.
    """
    hparams_file = Path(experiment_dir) / "hparams.yaml"

    # The context manager closes the handle even if parsing raises; the
    # previous ``open(...).read()`` left it to the garbage collector.
    # NOTE(review): FullLoader is not intended for untrusted input. If these
    # files can come from outside the project, switch to ``yaml.safe_load``.
    with open(hparams_file) as handle:
        hparams = yaml.load(handle, Loader=yaml.FullLoader)

    # An empty file parses to None, which cannot be unpacked with ``**``.
    return Namespace(**(hparams or {}))


def load_model(experiment_dir: str, dataset, hparams, tokenizer, collator, num_classes):
    """ Function that loads the model from an experiment folder.
    :param experiment_dir: Path (``str`` or ``Path``) to the experiment folder;
        checkpoints are read from its ``checkpoints`` sub-folder.
    :param dataset: Dataset name. ``"hoc"`` selects ``MultiLabelClassifier``,
        any other value selects ``MultiClassClassifier``.
    :param hparams: Object exposing ``encoder_model``, ``batch_size``,
        ``num_frozen_epochs``, ``encoder_learning_rate`` and ``learning_rate``
        as attributes; also forwarded as-is to ``load_from_checkpoint``.
    :param tokenizer: Forwarded unchanged to ``load_from_checkpoint``.
    :param collator: Forwarded unchanged to ``load_from_checkpoint``.
    :param num_classes: Unused here; kept so existing callers keep working.
    Return:
        - Model restored from the most recently modified ``.ckpt`` entry,
          after ``eval()`` and ``freeze()`` have been called on it.
    Raises:
        FileNotFoundError: if the checkpoints folder does not exist or holds
            no ``.ckpt`` entry.
    """
    checkpoint_dir = Path(experiment_dir) / "checkpoints"

    candidates = [
        entry
        for entry in checkpoint_dir.iterdir()
        if entry.name.endswith(".ckpt")
    ]
    if not candidates:
        raise FileNotFoundError(f"No .ckpt checkpoint found in {checkpoint_dir}")

    # iterdir() yields entries in arbitrary order, so taking "the last one"
    # was not reproducible. Pick the newest checkpoint explicitly and use the
    # file name as a deterministic tie-breaker.
    checkpoint_path = max(
        candidates, key=lambda entry: (entry.stat().st_mtime, entry.name)
    )

    classifier = MultiLabelClassifier if dataset == "hoc" else MultiClassClassifier

    model = classifier.load_from_checkpoint(
        checkpoint_path, hparams=hparams, tokenizer=tokenizer,
        collator=collator, encoder_model=hparams.encoder_model,
        batch_size=hparams.batch_size,
        num_frozen_epochs=hparams.num_frozen_epochs,
        encoder_learning_rate=hparams.encoder_learning_rate,
        learning_rate=hparams.learning_rate,
    )

    # Make sure model is in prediction mode
    model.eval()
    model.freeze()
    return model

def prototype(hparams, seed: int = 69):
    """ Function that builds a fresh classifier and its data module.
    :param hparams: Namespace-like object providing ``encoder_model``,
        ``data_path``, ``dataset``, ``batch_size`` and ``num_workers``. The
        object itself and all of its attributes (via ``vars``) are passed on
        to the classifier constructor.
    :param seed: Value handed to ``seed_everything`` before anything is
        built. Defaults to the previously hard-coded 69.
    Return:
        - Tuple ``(model, datamodule, hparams)``; ``hparams`` is the same
          object that was passed in.
    """
    seed_everything(seed)

    tokenizer = Tokenizer(hparams.encoder_model)
    collator = Collator(tokenizer)
    datamodule = MedDataModule(
        tokenizer, collator, hparams.data_path,
        hparams.dataset, hparams.batch_size,
        hparams.num_workers,
    )

    desc_tokens = datamodule.desc_tokens
    num_classes = datamodule.num_classes
    # NOTE(review): presumably the number of training examples; confirm
    # against MedDataModule.size.
    train_size = datamodule.size(dim=0)
    print("Finished loading data!")

    # Both classifiers take the same arguments, so only the class differs
    # (same selection rule as load_model).
    classifier = MultiLabelClassifier if hparams.dataset == 'hoc' else MultiClassClassifier
    model = classifier(
        desc_tokens, tokenizer, collator, num_classes, train_size, hparams, **vars(hparams)
    )

    return model, datamodule, hparams

In [2]:
# Hyper-parameters used by this notebook, written as a plain mapping and then
# exposed through attribute access (``hparams.batch_size`` and so on).
hparams = Namespace(**{
    "encoder_model": "bert-base-cased",
    "data_path": "./project/data",
    "dataset": "mtc",
    "batch_size": 2,
    "num_workers": 2,
    "random_sampling": False,
    "num_frozen_epochs": 0,
    "encoder_learning_rate": 1e-5,
    "learning_rate": 3e-5,
    "tgt_txt_col": "TEXT",
    "tgt_lbl_col": "LABEL",
    "n_lbl_attn_layer": 1,
    "static_desc_emb": True,
    "weight_decay_encoder": 0.05,
    "weight_decay_nonencoder": 0.1,
    "label_attn_lr": 0.0002,
})

SyntaxError: invalid syntax (<ipython-input-2-f65413967027>, line 1)