In [None]:
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="The path to data used for prediction.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="westlake-repl/SaProt_650M_AF2",
        help="The name of a pretrained model or path to a model which you want to use for training. You can use your local models or models uploaded to hugging face.",
    )
    parser.add_argument(
        "--architecture",
        type=str,
        default="SaProt",
        help="The name of a model architecture. 'ESM2', 'SaProt' or 'LSTM'.",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="The path to a model which you want to use for prediction.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=4, help="Batch size."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Set seed for reproducibility.",
    )
    parser.add_argument(
        "--use_amp",
        action="store_true",
        default=False,
        help="Use amp for mixed precision training.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of workers for data loading.",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=512,
        help="Max length of input sequence. Two tokens are used fo <cls> and <eos> tokens. So the actual length of input sequence is max_length - 2. Padding or truncation is applied to make the length of input sequence equal to max_length.",
    )
    parser.add_argument(
        "--used_sequence",
        type=str,
        default="left",
        help="How to use input sequence. 'left': use the left part of the sequence, 'right': use the right part of the sequence, 'both': use both side of the sequence, 'internal': use the internal part of the sequence.",
    )
    parser.add_argument(
        "--padding_side",
        type=str,
        default="right",
        help="Padding side. 'right' or 'left'.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/home2/sagawa/protein-half-life-prediction/ver20_56_2/",
        help="Output directory.",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="classification",
        help="Task. 'classification' or 'regression'.",
    )
    parser.add_argument(
        "--sequence_col",
        type=str,
        default="aa_foldseek",
        help="The column name of amino acid sequence.",
    )

    return parser.parse_args()

In [11]:
class config:
    data_path = "/home2/sagawa/protein-half-life-prediction/ver20_56_2/data/ver20_56_2.csv"
    model="facebook/esm2_t33_650M_UR50D"
    architecture = "ESM2"
    model_path = "/home2/sagawa/protein-half-life-prediction/ver20_56_2/outputs/ver20_56_2_20210909_1/model_0"
    batch_size = 4
    seed = 42
    use_amp = False
    num_workers = 4
    max_length = 512
    used_sequence = "left"
    padding_side = "right"
    output_dir = "/home2/sagawa/protein-half-life-prediction/ver20_56_2/"
    task = "classification"
    sequence_col = "aa_foldseek"
cfg = config()

In [33]:
from models import PLTNUM
import torch
model = PLTNUM(cfg)
model.load_state_dict(torch.load("/home2/sagawa/PLTNUM/classification/model_fold0.pth", map_location=torch.device('cpu'))["model"])
model

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


PLTNUM(
  (model): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 1280, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 1280, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-32): 33 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((1280,), eps=1e-05, elemen

In [35]:
torch.load("/home2/sagawa/PLTNUM/classification/model_fold0.pth", map_location=torch.device('cpu'))["model"]

OrderedDict([('model.embeddings.word_embeddings.weight',
              tensor([[ 0.0540, -0.0574, -0.1302,  ..., -0.2233,  0.1189, -0.0651],
                      [ 0.0614,  0.0292, -0.1028,  ..., -0.0330,  0.0668, -0.0575],
                      [-0.0896, -0.0439, -0.0575,  ..., -0.0582,  0.0655, -0.1094],
                      ...,
                      [ 0.0023,  0.0143,  0.0514,  ..., -0.0193,  0.0112,  0.0053],
                      [ 0.0731,  0.0470,  0.0346,  ...,  0.1118,  0.0465, -0.0243],
                      [ 0.0492,  0.0254, -0.1112,  ..., -0.0257,  0.0599, -0.0614]])),
             ('model.embeddings.position_embeddings.weight',
              tensor([[ 0.0026,  0.0083,  0.0133,  ..., -0.0090,  0.0321,  0.0109],
                      [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      [ 0.0140,  0.0045,  0.0005,  ...,  0.0225, -0.0110, -0.0185],
                      ...,
                      [ 0.0315,  0.0066, -0.0124,  ..., -0.0021, -0.0

In [None]:
!mkdir /home2/sagawa/PLTNUM/classification_automodel

In [30]:
config = AutoConfig.from_pretrained("facebook/esm2_t33_650M_UR50D")
config.save_pretrained("/home2/sagawa/PLTNUM/classification_automodel")
# save model
torch.save(model.state_dict(), "/home2/sagawa/PLTNUM/classification_automodel/pytorch_model.bin")

In [41]:
from transformers import PreTrainedModel, AutoConfig, AutoModel
import torch.nn as nn
class PLTNUM(PreTrainedModel):
    config_class = AutoConfig
    def __init__(self, config, cfg):
        super().__init__(config)
        self.cfg = cfg
        self.config = config
        self.model = AutoModel.from_pretrained(self.config._name_or_path)
        self.fc_dropout1 = nn.Dropout(0.8)
        self.fc_dropout2 = nn.Dropout(0.4 if cfg.task == "classification" else 0.8)
        self.fc_dropout2 = nn.Dropout(0.8)  
        self.fc = nn.Linear(self.config.hidden_size, 1)
        self._init_weights(self.fc)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, inputs):
        outputs = self.model(**inputs)
        last_hidden_states = outputs[0][:, 0]
        output = (
            self.fc(self.fc_dropout1(last_hidden_states))
            + self.fc(self.fc_dropout2(last_hidden_states))
        ) / 2
        return output

    def create_embedding(self, inputs):
        outputs = self.model(**inputs)
        last_hidden_states = outputs[0][:, 0]
        return last_hidden_states

In [7]:
model = PLTNUM.from_pretrained("facebook/esm2_t33_650M_UR50D")

In [42]:
model = PLTNUM.from_pretrained("/home2/sagawa/PLTNUM/classification_automodel", cfg)

Some weights of EsmModel were not initialized from the model checkpoint at /home2/sagawa/PLTNUM/classification_automodel and are newly initialized: ['contact_head.regression.bias', 'contact_head.regression.weight', 'embeddings.position_embeddings.weight', 'embeddings.word_embeddings.weight', 'encoder.emb_layer_norm_after.bias', 'encoder.emb_layer_norm_after.weight', 'encoder.layer.0.LayerNorm.bias', 'encoder.layer.0.LayerNorm.weight', 'encoder.layer.0.attention.LayerNorm.bias', 'encoder.layer.0.attention.LayerNorm.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.rotary_embeddings.inv_freq', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.intermediate.dense.bias',

In [43]:
model.state_dict()

OrderedDict([('model.embeddings.word_embeddings.weight',
              tensor([[ 0.0540, -0.0574, -0.1302,  ..., -0.2233,  0.1189, -0.0651],
                      [ 0.0614,  0.0292, -0.1028,  ..., -0.0330,  0.0668, -0.0575],
                      [-0.0896, -0.0439, -0.0575,  ..., -0.0582,  0.0655, -0.1094],
                      ...,
                      [ 0.0023,  0.0143,  0.0514,  ..., -0.0193,  0.0112,  0.0053],
                      [ 0.0731,  0.0470,  0.0346,  ...,  0.1118,  0.0465, -0.0243],
                      [ 0.0492,  0.0254, -0.1112,  ..., -0.0257,  0.0599, -0.0614]])),
             ('model.embeddings.position_embeddings.weight',
              tensor([[ 0.0026,  0.0083,  0.0133,  ..., -0.0090,  0.0321,  0.0109],
                      [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      [ 0.0140,  0.0045,  0.0005,  ...,  0.0225, -0.0110, -0.0185],
                      ...,
                      [ 0.0315,  0.0066, -0.0124,  ..., -0.0021, -0.0

In [44]:
model.fc_dropout2

Dropout(p=0.8, inplace=False)

In [40]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("/home2/sagawa/PLTNUM/classification_automodel")

token_embeddingをresizeしてweightを保存

In [45]:
tokenizer.add_special_tokens({"additional_special_tokens": ["@@@@@@", "ooo"]})
tokenizer

EsmTokenizer(name_or_path='/home2/sagawa/PLTNUM/classification_automodel', vocab_size=33, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>', 'additional_special_tokens': ['@@@@@@', 'ooo']}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<cls>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("<eos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	33: AddedToken("@@@@@@", rstrip=False,

In [50]:
model.model.resize_token_embeddings(len(tokenizer))
model.config

EsmConfig {
  "_name_or_path": "/home2/sagawa/PLTNUM/classification_automodel",
  "architectures": [
    "EsmForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "emb_layer_norm_before": false,
  "esmfold_config": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1280,
  "initializer_range": 0.02,
  "intermediate_size": 5120,
  "is_folding_model": false,
  "layer_norm_eps": 1e-05,
  "mask_token_id": 32,
  "max_position_embeddings": 1026,
  "model_type": "esm",
  "num_attention_heads": 20,
  "num_hidden_layers": 33,
  "pad_token_id": 1,
  "position_embedding_type": "rotary",
  "token_dropout": true,
  "torch_dtype": "float32",
  "transformers_version": "4.38.1",
  "use_cache": true,
  "vocab_list": null,
  "vocab_size": 33
}

In [51]:
torch.save(model.state_dict(), "/home2/sagawa/PLTNUM/classification_automodel/pytorch_model.bin")

In [52]:
model = PLTNUM.from_pretrained("/home2/sagawa/PLTNUM/classification_automodel", cfg)

Some weights of EsmModel were not initialized from the model checkpoint at /home2/sagawa/PLTNUM/classification_automodel and are newly initialized: ['contact_head.regression.bias', 'contact_head.regression.weight', 'embeddings.position_embeddings.weight', 'embeddings.word_embeddings.weight', 'encoder.emb_layer_norm_after.bias', 'encoder.emb_layer_norm_after.weight', 'encoder.layer.0.LayerNorm.bias', 'encoder.layer.0.LayerNorm.weight', 'encoder.layer.0.attention.LayerNorm.bias', 'encoder.layer.0.attention.LayerNorm.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.rotary_embeddings.inv_freq', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.intermediate.dense.bias',

RuntimeError: Error(s) in loading state_dict for PLTNUM:
	size mismatch for model.embeddings.word_embeddings.weight: copying a param with shape torch.Size([35, 1280]) from checkpoint, the shape in current model is torch.Size([33, 1280]).
	You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

In [56]:
model.config.vocab_size = len(tokenizer)
model.config.save_pretrained("/home2/sagawa/PLTNUM/classification_automodel")

In [57]:
model = PLTNUM.from_pretrained("/home2/sagawa/PLTNUM/classification_automodel", cfg)

Some weights of EsmModel were not initialized from the model checkpoint at /home2/sagawa/PLTNUM/classification_automodel and are newly initialized: ['contact_head.regression.bias', 'contact_head.regression.weight', 'embeddings.position_embeddings.weight', 'embeddings.word_embeddings.weight', 'encoder.emb_layer_norm_after.bias', 'encoder.emb_layer_norm_after.weight', 'encoder.layer.0.LayerNorm.bias', 'encoder.layer.0.LayerNorm.weight', 'encoder.layer.0.attention.LayerNorm.bias', 'encoder.layer.0.attention.LayerNorm.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.rotary_embeddings.inv_freq', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.intermediate.dense.bias',

In [None]:
tokenizer.save_pretrained