In [1]:
import os
import argparse
import datetime
import json
import time
import warnings
from pathlib import Path
from typing import Dict, List

import torch
from tqdm import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from utils import calculate_rouge, chunks, parse_numeric_n_bool_cl_kwargs, use_task_specific_params


In [3]:
model_name = './output/2020-12-15-01-44-14/best_tfmr'
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to('cuda')


Some weights of the model checkpoint at ./output/2020-12-15-01-44-14/best_tfmr were not used when initializing PegasusForConditionalGeneration: ['model.encoder.embed_speaker.weight']
- This IS expected if you are initializing PegasusForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing PegasusForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
state_dict = torch.load(os.path.join(model_name,"pytorch_model.bin"), map_location="cpu")


In [47]:
print('model.encoder.embed_speaker.weight' in state_dict.keys())

True


In [14]:
state_dict['model.encoder.embed_speaker.weight']

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.4873,  0.4773,  0.6709,  ..., -0.0336,  1.0234, -1.6810],
        [-0.7433,  2.3412,  0.5041,  ...,  0.5177, -0.1096, -0.6892],
        ...,
        [ 1.1140,  0.2462,  1.6674,  ..., -0.0325, -0.1390, -0.9116],
        [-0.2555,  2.9544, -1.4787,  ..., -0.4960,  0.4160,  1.1340],
        [ 1.3796, -1.3169,  0.1431,  ...,  0.2244, -1.1140, -0.1873]])

In [16]:
from speaker_embed_encoder import BartEncoderWithSpeakerEmbedding
speaker_encoder = BartEncoderWithSpeakerEmbedding(model.config, model.model.shared, use_turn_embeds=False).to('cuda')

In [38]:
for name, param in model.model.encoder.named_parameters():
    print(name)
    if name=="embed_positions.weight":
        print(param==model.model.encoder.embed_positions.weight)
    speaker_encoder.state_dict()[name][:] = param


embed_tokens.weight
embed_positions.weight
tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]], device='cuda:0')
layers.0.self_attn.k_proj.weight
layers.0.self_attn.k_proj.bias
layers.0.self_attn.v_proj.weight
layers.0.self_attn.v_proj.bias
layers.0.self_attn.q_proj.weight
layers.0.self_attn.q_proj.bias
layers.0.self_attn.out_proj.weight
layers.0.self_attn.out_proj.bias
layers.0.self_attn_layer_norm.weight
layers.0.self_attn_layer_norm.bias
layers.0.fc1.weight
layers.0.fc1.bias
layers.0.fc2.weight
layers.0.fc2.bias
layers.0.final_layer_norm.weight
layers.0.final_layer_norm.bias
layers.1.self_attn.k_proj.weight
layers.1.self_attn.k_proj.bias
layers.1.self_attn.v_proj.weight
layers.1.self_attn.v_proj.bias
layers.1.self_

In [39]:
speaker_encoder.embed_positions.weight == model.model.encoder.embed_positions.weight

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]], device='cuda:0')

In [42]:
speaker_encoder.embed_speaker.weight = torch.nn.Parameter(state_dict['model.encoder.embed_speaker.weight'])

In [44]:
model.model.encoder = speaker_encoder

In [48]:
speaker_encoder.embed_speaker.weight

Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.4873,  0.4773,  0.6709,  ..., -0.0336,  1.0234, -1.6810],
        [-0.7433,  2.3412,  0.5041,  ...,  0.5177, -0.1096, -0.6892],
        ...,
        [ 1.1140,  0.2462,  1.6674,  ..., -0.0325, -0.1390, -0.9116],
        [-0.2555,  2.9544, -1.4787,  ..., -0.4960,  0.4160,  1.1340],
        [ 1.3796, -1.3169,  0.1431,  ...,  0.2244, -1.1140, -0.1873]],
       requires_grad=True)