In [1]:
import warnings
# Silences *every* warning for the rest of the kernel session (no category is
# given, so deprecation and user warnings from any library are hidden too).
# Narrow this filter if warnings become relevant while debugging.
warnings.filterwarnings('ignore')

In [None]:
import torch
import torch.nn as nn

from transformers import BertModel

# 1. Build the model

## 1.1. Setup device agnostic code

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

## 1.2. Construct the model

In Elden Ring, Godrick the Grafted is the first shardbearer most players face. A descendant of the Golden Lineage (the bloodline that began with Queen Marika the Eternal and her first consort, Godfrey), he took up residence in Stormveil Castle after the Elden Ring was shattered. There he could practice his art of grafting, a grotesque act that involves attaching parts of other living beings to oneself in order to gain power.

<div align="center">
  <img src="../../assets/elden-ring-godrick.jpg" width="700"/>
</div>

One common and clean way to “graft” two networks together in PyTorch is to wrap them in a single `nn.Module` subclass. In other words, we create a new class that holds a pre-processing model, a core model, and a post-processing model as submodules, and we define a `forward()` method that feeds each network's output into the next one.

In [None]:
class Preprocessor(nn.Module):
    """Placeholder for the step that turns enciphered ("cryptic") input into BERT input.

    `GodrickModel.forward` unpacks this module's output with `**`, so a real
    implementation must return a mapping of `BertModel` keyword arguments
    (e.g. `input_ids`, `attention_mask`, optionally `token_type_ids`).
    """

    def __init__(self):
        super().__init__()

    def forward(self, encyphered_input):
        # The previous stub returned `...` (Ellipsis), which only surfaced later
        # as an opaque TypeError when BERT tried to `**`-unpack it. Fail loudly
        # and at the right place until the decryption step is implemented.
        raise NotImplementedError(
            "Preprocessor.forward is a stub: implement the decryption step."
        )

class Postprocessor(nn.Module):
    """Placeholder for the step that maps BERT's output back to enciphered ("cryptic") output.

    It receives whatever the wrapped `BertModel` returns; the exact output type
    depends on the transformers version and config — TODO confirm when implementing.
    """

    def __init__(self):
        super().__init__()

    def forward(self, decyphered_output):
        # The previous stub returned `...` (Ellipsis), silently producing a
        # meaningless "prediction". Raise instead until this step is implemented.
        raise NotImplementedError(
            "Postprocessor.forward is a stub: implement the encryption step."
        )

In [None]:
# Load the three pieces that will be grafted together.
# `map_location=device` lets checkpoints saved on a GPU load on a CPU-only
# machine; `weights_only=True` refuses to unpickle arbitrary objects from disk.
preprocessor = Preprocessor()
preprocessor.load_state_dict(
    torch.load('./models/preprocessor.pt', map_location=device, weights_only=True)
)

bert_model = BertModel.from_pretrained('bert-base-uncased')

postprocessor = Postprocessor()
postprocessor.load_state_dict(
    torch.load('./models/postprocessor.pt', map_location=device, weights_only=True)
)

In [None]:
class GodrickModel(nn.Module):
    """Graft a preprocessor, a BERT backbone and a postprocessor into one module.

    Because all three are assigned as attributes, they are registered as
    submodules, so `.to()`, `.parameters()` and `.state_dict()` cover the
    whole pipeline. The attribute names `preprocessor`, `bert` and
    `postprocessor` become state-dict key prefixes, so keep them stable.
    """

    def __init__(self, preprocessor, bert_model, postprocessor):
        super().__init__()
        # Registration order (and therefore state-dict key order) is
        # preprocessor -> bert -> postprocessor.
        self.preprocessor = preprocessor
        self.bert = bert_model
        self.postprocessor = postprocessor

    def forward(self, raw_input):
        """Run raw (cryptic) input through preprocess -> BERT -> postprocess."""
        # Hugging Face models take keyword arguments, so the preprocessor is
        # expected to return a mapping such as
        # {"input_ids": ..., "attention_mask": ..., "token_type_ids": ... (optional)}.
        bert_kwargs = self.preprocessor(raw_input)
        hidden = self.bert(**bert_kwargs)
        # Map BERT's output back into the cryptic token space.
        return self.postprocessor(hidden)

In [None]:
# Graft the three loaded components together and move the whole pipeline
# (all registered submodules) to `device`. `GodrickModel` requires all three
# components; calling it with no arguments raises a TypeError.
model0 = GodrickModel(preprocessor, bert_model, postprocessor).to(device)
model0.state_dict()

## 1.3. Making test predictions

In [None]:
# Small float32 sample, allocated directly on the target device.
X_test = torch.tensor([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]], device=device)

In [None]:
# `untrained_logits`: the output of the model before any additional handling.
# `inference_mode()` only disables autograd tracking; `eval()` is what switches
# layers such as dropout / batch-norm to inference behaviour, so call both.
model0.eval()
with torch.inference_mode():
    # X_test already lives on `device` (it was created there).
    untrained_logits = model0(X_test)
untrained_logits.shape, untrained_logits[:5].squeeze()

In [None]:
# `untrained_preds_probs`: class probabilities obtained by applying softmax to `untrained_logits`

In [None]:
# `untrained_preds`: predicted class labels obtained by taking the argmax of `untrained_preds_probs`

## 1.4. Evaluation metrics

In [None]:
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(y_true, y_pred, average="weighted"):
    """Compute accuracy and F1 score for a set of predictions.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels, aligned with `y_true`.
        average: averaging strategy passed to `sklearn.metrics.f1_score`
            (defaults to "weighted", the previous hard-coded behaviour).

    Returns:
        dict with keys "accuracy" and "f1".
    """
    f1 = f1_score(y_true, y_pred, average=average)
    acc = accuracy_score(y_true, y_pred)
    return {"accuracy": acc, "f1": f1}

## 1.5. Save and load trained model

In [None]:
from pathlib import Path

# Create the checkpoint directory (no error if it already exists) and write
# model0's state dict into it, building the file path from the same Path object.
model_dir = Path("models")
model_dir.mkdir(exist_ok=True)
torch.save(obj=model0.state_dict(), f=model_dir / "godrick_model.pt")

In [None]:
# Shell listing of the checkpoint directory to confirm the file was written
# (POSIX `ls`; will not work on a plain Windows shell).
!ls -l ./models

In [None]:
# Rebuild the architecture from *fresh* component instances. Reusing
# `preprocessor` / `bert_model` / `postprocessor` would share modules with
# model0, which would make the comparison below trivially true.
# `BertModel(config)` builds the same architecture without re-downloading
# pretrained weights (they are overwritten by the checkpoint anyway).
loaded_model = GodrickModel(
    Preprocessor(), BertModel(bert_model.config), Postprocessor()
).to(device)
loaded_model.load_state_dict(
    torch.load('./models/godrick_model.pt', map_location=device, weights_only=True)
)

# Compare tensors instead of dumping two full BERT state dicts.
original_state, loaded_state = model0.state_dict(), loaded_model.state_dict()
all(torch.equal(original_state[key], loaded_state[key]) for key in original_state)