Based on `model.py`.

In [2]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch.nn as nn

class policy_network(nn.Module):
    """Sentence encoder with an optional trainable linear projection.

    Wraps a pretrained Hugging Face checkpoint (loaded through
    ``AutoModelForTokenClassification``) together with its tokenizer. ``forward``
    uses only the last-layer hidden state at position 0 (the [CLS] token for
    BERT-style models); the token-classification logits are never read.

    Attributes:
        tokenizer: Tokenizer loaded from ``model_config``.
        model: The pretrained transformer.
        linear: ``nn.Linear`` projection head, or ``None`` when ``add_linear``
            is False.
        embedding_size: Output size of ``linear``; only set when ``add_linear``
            is True.
    """

    def __init__(self, model_config="bert-base-uncased", add_linear=False, embedding_size=128, freeze_encoder=True, context_net=False):
        """Load the tokenizer/model and optionally build the linear head.

        Args:
            model_config: Checkpoint name or path passed to ``from_pretrained``.
            add_linear: If True, add an ``nn.Linear`` head on top of the encoder.
            embedding_size: Output dimension of the linear head.
            freeze_encoder: If True, disable gradients for every parameter of
                ``self.model`` so only the linear head (if any) is trainable.
            context_net: If True, the linear head expects inputs of size
                ``2 * hidden_size`` instead of ``hidden_size``.
                NOTE(review): ``forward`` with ``bert_forward=True`` produces
                ``hidden_size`` features, so this mode presumably relies on the
                caller passing pre-concatenated features with
                ``bert_forward=False`` -- confirm against callers.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_config)
        print("model_config:", model_config)
        self.model = AutoModelForTokenClassification.from_pretrained(model_config)

        # Freeze the transformer so that only the linear layer is trained.
        if freeze_encoder:
            for param in self.model.parameters():
                param.requires_grad = False

        if add_linear:
            # Small, adjustable linear layer on top of the encoder, tuned through RL.
            self.embedding_size = embedding_size
            hidden_size = self.model.config.hidden_size  # 768 for bert-base-uncased, distilbert-base-uncased
            input_dim = hidden_size * 2 if context_net else hidden_size
            self.linear = nn.Linear(input_dim, embedding_size)
        else:
            self.linear = None

    def forward(self, input_list, bert_forward=True, linear_forward=True):
        """Embed the inputs and optionally project them with the linear head.

        Args:
            input_list: When ``bert_forward`` is True, the text input handed to
                the tokenizer. When ``bert_forward`` is False, pre-computed
                features that are fed to the linear head (or returned as-is).
            bert_forward: Run the tokenizer and transformer on ``input_list``.
            linear_forward: Apply the linear head, if one was created.

        Returns:
            The position-0 hidden state of the last layer
            (``len(input_list) x hidden_size``), projected to
            ``embedding_size`` when the linear head is applied. With
            ``bert_forward=False`` and no head applied, ``input_list`` is
            returned unchanged.
        """
        if bert_forward:
            encoded = self.tokenizer(input_list, truncation=True, padding=True, return_tensors="pt").to(self.model.device)
            output = self.model(**encoded, output_hidden_states=True)
            # Last layer hidden states, then the [CLS] position.
            last_hidden_states = output.hidden_states[-1]
            sentence_embedding = last_hidden_states[:, 0, :]  # len(input_list) x hidden_size
        else:
            # Inputs are already features. Previously this path left
            # `sentence_embedding` unbound (UnboundLocalError) whenever the
            # linear head was skipped or absent.
            sentence_embedding = input_list

        # Explicit None check: module truthiness is not a reliable "exists" test.
        if linear_forward and self.linear is not None:
            sentence_embedding = self.linear(sentence_embedding)  # ... x embedding_size
        return sentence_embedding

In [3]:
# Instantiate the policy with the default checkpoint: encoder weights frozen,
# plus a trainable linear head (default embedding_size).
model = policy_network(freeze_encoder=True, add_linear=True)

model_config: bert-base-uncased


Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
