In [4]:
# Run from the repository root so the relative paths used below ("src",
# "configs/", "out/") resolve. Guarded so that re-running this cell does not
# climb one more directory each time (a bare `%cd ..` is not idempotent).
import os

if not os.path.isdir("src"):
    os.chdir("..")
os.getcwd()

/home/pablo/nlp-course/assignment


In [5]:
import sys

# Put src/ on the import path so modules inside it are importable by their
# top-level names. Guarded so re-running this cell doesn't append duplicates.
if "src" not in sys.path:
    sys.path.append("src")

In [12]:
import torch
import json
import os

from src.models.model_builder import ModelBuilder
from src.heads.classification_head import ModelWithClassificationHead
from src.data_loaders.pan23 import PAN23DataModule

In [13]:
def extract_torch_model(config, checkpoint_path, save_path, task,
                        dropout_p=0.25, ff_dim=2048, joint_pairs=True):
    """
    Extract the model from a PyTorch Lightning checkpoint and save it as a
    standalone ``torch.save`` file (e.g. to upload to the Hugging Face Hub).

    The architecture is rebuilt from the training ``config`` and the
    checkpoint's weights are loaded into it, so ``config`` must be the one
    that was used for training.

    Args:
        config: Parsed training config. Reads ``config["model_name"]``,
            ``config["model_params"]`` and
            ``config["pan_train_params"]["data_module_params"]``.
            The caller's dict is not modified.
        checkpoint_path: Path to the Lightning ``.ckpt`` file.
        save_path: Output directory (created if missing). The model is
            written to ``<save_path>/<task>.pt``.
        task: Task name such as ``"task1"``. It is joined onto the data
            module's ``data_path`` and used as the output file name.
        dropout_p, ff_dim, joint_pairs: Classification-head settings. The
            defaults are the values previously hard-coded here. They should
            match the head used in training; otherwise the (strict)
            ``load_state_dict`` call below is likely to fail.

    Returns:
        The path of the saved model file.
    """
    model_name = config["model_name"]
    model_params = config["model_params"]
    # Shallow-copy instead of mutating the caller's config: writing back into
    # it would make a reused config accumulate task directories in data_path.
    data_module_params = dict(config["pan_train_params"]["data_module_params"])
    data_module_params["data_path"] = os.path.join(data_module_params["data_path"], task)

    os.makedirs(save_path, exist_ok=True)
    model_path = os.path.join(save_path, f"{task}.pt")

    # map_location="cpu" lets GPU-trained checkpoints be extracted on CPU-only
    # machines. load_state_dict copies values into the model's own parameters,
    # so the device the checkpoint tensors were loaded onto does not matter.
    # NOTE(review): newer torch versions default torch.load to weights_only=True,
    # which may reject non-tensor objects stored in Lightning checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Lightning prefixes each key with the attribute name the network was stored
    # under in the LightningModule; here that is assumed to be "model." (as the
    # original code did). Strip it only as a leading prefix: replacing the first
    # "model." anywhere would also rewrite keys that merely contain it mid-name.
    prefix = "model."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["state_dict"].items()
    }

    # The data module is built here only to obtain the vocab size and padding
    # index that the model constructor needs.
    data_module = PAN23DataModule.from_joint_config(data_module_params)

    model = ModelBuilder.build(
        model_name,
        model_params,
        data_module.get_vocab_size(),
        data_module.get_padding_idx(),
    )

    model_with_class_head = ModelWithClassificationHead(
        model,
        model.get_out_embedding_dim(),
        dropout_p=dropout_p,
        ff_dim=ff_dim,
        joint_pairs=joint_pairs,
    )
    model_with_class_head.load_state_dict(state_dict)

    # Saves the whole pickled module (not just its state_dict), so loading it
    # later requires the project's model classes to be importable.
    torch.save(model_with_class_head, model_path)
    return model_path

In [19]:
# Selected checkpoint for each task, in order: index 0 -> task1, 1 -> task2, 2 -> task3.
checkpoint_paths = [
    "out/roberta-base/finetuned/task1/lightning_logs/version_12/checkpoints/epoch=9-val_f1_score=0.98.ckpt",
    "out/roberta-base/finetuned/task2/lightning_logs/version_6/checkpoints/epoch=8-val_f1_score=0.76.ckpt",
    "out/roberta-base/finetuned/task3/lightning_logs/version_5/checkpoints/epoch=7-val_f1_score=0.68.ckpt",
]
for ind, checkpoint_path in enumerate(checkpoint_paths):
    # Load a fresh config per task so any in-place edits extract_torch_model
    # makes to it (e.g. data_path) don't carry over to the next task.
    with open("configs/roberta-base.json", "r") as f:
        config = json.load(f)
    task = f"task{ind+1}"
    # Pass the computed task, not a literal: hard-coding "task3" made every
    # checkpoint use task3's data path and overwrite task3.pt.
    extract_torch_model(config, checkpoint_path, "selected_models/roberta-base/", task)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
