In [None]:
import os
import sys

# Location of the PALM checkout that provides the `palm` package imported below.
# Defaults to the original cluster path; set the PALM_ROOT environment variable
# to run this notebook against a different checkout.
PALM_ROOT = os.environ.get("PALM_ROOT", "/apdcephfs/private_chewu/PALM")

# Guard against duplicates so re-running this cell does not keep growing sys.path.
if PALM_ROOT not in sys.path:
    sys.path.append(PALM_ROOT)

Load PALM through torch.hub from the local repository checkout (PyTorch >= 1.1):

In [None]:
import torch
from palm.models.palm import PALMModel

# Resolve the 'palm.base' entry point via torch.hub. With source='local',
# torch.hub reads the hubconf.py found in the given directory instead of
# fetching a repository from GitHub.
palm = torch.hub.load(
    repo_or_dir='../',
    model='palm.base',
    source='local',
)
# Evaluation mode disables dropout; skip this call to keep the model in
# training mode for fine-tuning.
palm.eval()

Load PALM (for PyTorch 1.0 or custom models):

In [None]:
from palm.models.palm import PALMModel

# Build the model directly from a checkpoint directory instead of torch.hub.
palm = PALMModel.from_pretrained(
    '/apdcephfs/share_1351585/FM/NLG/zh/palm_pretrain_checkpoints/',
    checkpoint_file='checkpoint_best.pt',
)
# Evaluation mode disables dropout; omit this call when fine-tuning.
palm.eval()

Apply Byte-Pair Encoding (BPE) to input text:

In [None]:
# Encode a batch of raw sentences into model inputs.
inputs = palm.encode(['22日，国家航天局公布祝融号火星车携带的前避障相机和后避障相机拍摄的驶离过程影像。',
                       '上海：6月1日至中下旬全面恢复正常生产生活秩序。'])
print(inputs)

# Run the encoder on the batch. The encoded batch keeps its tensors under
# 'net_input' (the same access pattern the next cell uses), so pass the token
# ids and lengths explicitly rather than handing the whole batch object to
# `src_tokens`.
net_input = inputs['net_input']
palm.model.encoder(src_tokens=net_input['src_tokens'], src_lengths=net_input['src_lengths'])

In [None]:
# Encode a second batch of two sentences.
inputs = palm.encode([
    '打破200年军事不结盟传统！瑞典执政党决定支持该国加入北约。',
    '金正恩亲自前往药店了解情况 下令投入军医稳定平壤供药',
])
print(inputs)

# Feed the token ids and their lengths from the encoded batch to the encoder.
net_input = inputs['net_input']
palm.model.encoder(
    src_tokens=net_input['src_tokens'],
    src_lengths=net_input['src_lengths'],
)

Extract features from PALM:

In [None]:
# NOTE(review): `tokens` is not assigned by any cell above this one; on a fresh
# kernel this cell raises NameError unless a later cell that defines `tokens`
# has been run first — TODO define the input here.
# Extract the last layer's features
last_layer_features = palm.extract_features(tokens)
# NOTE(review): the hard-coded size presumably reads as (batch, tokens, hidden)
# for one specific input and model size — confirm against the loaded checkpoint.
assert last_layer_features.size() == torch.Size([1, 5, 1024])

# Extract all layers' features from decoder (layer 0 is the embedding layer)
all_layers = palm.extract_features(tokens, return_all_hiddens=True)
# NOTE(review): 13 presumably counts the embedding layer plus 12 layers — TODO confirm.
assert len(all_layers) == 13
# The final entry of the per-layer list must equal the last-layer features above.
assert torch.all(all_layers[-1] == last_layer_features)

Use PALM for sentence-pair classification tasks:


In [None]:
# Download PALM already finetuned for MNLI
palm = torch.hub.load('pytorch/fairseq', 'palm.large.mnli')
palm.eval()  # disable dropout for evaluation

# Encode a first pair of sentences and predict its MNLI label.
# Expected label: 0 (contradiction).
first_pair_tokens = palm.encode('PALM is a seq2seq model.', 'PALM is not sequence to sequence.')
first_pair_label = palm.predict('mnli', first_pair_tokens).argmax()

# Encode another pair of sentences. `tokens` is left holding this second
# encoding because later cells reuse it.
# Expected label: 2 (entailment).
tokens = palm.encode('PALM is denoising autoencoder.', 'PALM is version of autoencoder.')
second_pair_label = palm.predict('mnli', tokens).argmax()

# A notebook only renders the value of a cell's final expression, so return
# both predictions together instead of discarding the first one mid-cell.
first_pair_label, second_pair_label

Register a new (randomly initialized) classification head:

In [None]:
# Attach a new 3-class head named 'new_task' to the model, then score the
# current `tokens` with it.
palm.register_classification_head(
    'new_task',
    num_classes=3,
)
logprobs = palm.predict('new_task', tokens)

Batched prediction:


In [None]:
import torch
from fairseq.data.data_utils import collate_tokens

palm = torch.hub.load('pytorch/fairseq', 'palm.large.mnli')
palm.eval()

batch_of_pairs = [
    ['PALM is a seq2seq model.', 'PALM is not sequence to sequence.'],
    ['PALM is denoising autoencoder.', 'PALM is version of autoencoder.'],
]

# Encode each sentence pair on its own, then let collate_tokens combine the
# encodings into a single batch, padding with index 1.
encoded_pairs = [
    palm.encode(first, second) for first, second in batch_of_pairs
]
batch = collate_tokens(encoded_pairs, pad_idx=1)

# One prediction per pair: take the highest-scoring class along dim 1.
logprobs = palm.predict('mnli', batch)
print(logprobs.argmax(dim=1))

Using the GPU:

In [None]:
# The 'new_task' head has to exist on the model object currently bound to
# `palm`. An earlier cell rebinds `palm` to a freshly loaded model after the
# head was first registered, so register it again here; doing so before
# .cuda() means the head is moved to the GPU together with the model.
palm.register_classification_head('new_task', num_classes=3)
palm.cuda()
palm.predict('new_task', tokens)

Filling masks:

PALM can be used to fill multiple `<mask>` tokens in the input.

In [None]:
palm = torch.hub.load('pytorch/fairseq', 'palm.base')
palm.eval()

# A single input sentence containing two <mask> placeholders to be filled.
masked_sentences = ['The cat <mask> on the <mask>.']
palm.fill_mask(masked_sentences, topk=3, beam=10)
# Example output:
# [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))]]