forked from mindspore-courses/MindSpore-Tutorial
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sample.py
95 lines (75 loc) · 3.27 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""测试"""
import argparse
import json
import matplotlib.pyplot as plt
import mindspore
import numpy as np
from PIL import Image
from mindspore import Tensor
from mindspore.dataset.transforms import Compose
from mindspore.dataset.vision import transforms
from model import EncoderCNN, DecoderRNN
def load_image(image_path, transform=None):
    """Load an image, resize it to 256x256, apply *transform*, and return a batched Tensor.

    Args:
        image_path: path to the image file on disk.
        transform: optional callable applied to the resized PIL image
            (e.g. a mindspore ``Compose`` pipeline).

    Returns:
        A mindspore ``Tensor`` with a leading batch dimension of 1.
    """
    pil_img = Image.open(image_path).convert('RGB')
    pil_img = pil_img.resize([256, 256], Image.LANCZOS)
    data = transform(pil_img) if transform is not None else pil_img
    # NOTE(review): indexing with [0] assumes the transform returns a
    # tuple/sequence whose first element is the image array — confirm
    # against the mindspore transforms API.
    batched = Tensor(data[0]).unsqueeze(0)
    return batched
def main(_args):
    """Generate a caption for one image using trained encoder/decoder checkpoints.

    Args:
        _args: parsed command-line namespace with ``image``, ``encoder_path``,
            ``decoder_path``, ``json_path``, ``embed_size``, ``hidden_size``
            and ``num_layers`` attributes.

    Side effects: prints the generated caption and shows the input image.
    """
    # Image preprocessing — normalization statistics match the values used
    # at training time (ImageNet mean/std).
    transform = Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225), is_hwc=False)])

    # Load vocabulary wrapper (JSON object mapping word -> id).
    # JSON is text, so open in text mode with an explicit encoding.
    with open(_args.json_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)

    # Invert the vocabulary ONCE so decoding below is O(1) per token,
    # instead of rebuilding key/value lists for every generated word.
    # setdefault keeps the first word for a duplicated id, matching the
    # original list.index() behavior.
    id_to_word = {}
    for word, word_id in vocab.items():
        id_to_word.setdefault(word_id, word)

    # Build models in eval mode (batchnorm uses moving mean/variance).
    encoder = EncoderCNN(_args.embed_size)
    decoder = DecoderRNN(_args.embed_size, _args.hidden_size, len(vocab), _args.num_layers)
    encoder.set_train(False)
    decoder.set_train(False)

    # Load the trained model parameters.
    mindspore.load_param_into_net(encoder, mindspore.load_checkpoint(_args.encoder_path))
    mindspore.load_param_into_net(decoder, mindspore.load_checkpoint(_args.decoder_path))

    # Prepare the image; load_image already returns a batched Tensor, so no
    # extra mindspore.Tensor() wrap is needed.
    image_tensor = load_image(_args.image, transform)

    # Generate a caption from the image.
    feature = encoder(image_tensor)
    sampled_ids = decoder.sample(feature)
    sampled_ids = sampled_ids[0].asnumpy()  # (1, max_seq_length) -> (max_seq_length,)

    # Convert word ids to words, stopping at the end-of-sentence token.
    sampled_caption = []
    for word_id in sampled_ids:
        word = id_to_word[int(word_id)]
        sampled_caption.append(word)
        if word == '<end>':
            break
    sentence = ' '.join(sampled_caption)

    # Print the generated caption and display the source image.
    print(sentence)
    image = Image.open(_args.image)
    plt.imshow(np.asarray(image))
    plt.show()  # without show() the figure is never rendered in a script
if __name__ == '__main__':
    # Command-line entry point: parse arguments and run caption generation.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--image', type=str, required=True, help='input image for generating caption')
    arg_parser.add_argument('--encoder_path', type=str, default='models/encoder-1-100.ckpt',
                            help='path for trained encoder')
    arg_parser.add_argument('--decoder_path', type=str, default='models/decoder-1-100.ckpt',
                            help='path for trained decoder')
    arg_parser.add_argument('--json_path', type=str,
                            default='../../../data/COCO/mindrecord/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json')
    # Model parameters — must match the values used by train.py.
    arg_parser.add_argument('--embed_size', type=int, default=256, help='dimension of word embedding vectors')
    arg_parser.add_argument('--hidden_size', type=int, default=512, help='dimension of lstm hidden states')
    arg_parser.add_argument('--num_layers', type=int, default=1, help='number of layers in lstm')
    cli_args = arg_parser.parse_args()
    main(cli_args)