Skip to content

Commit b579373

Browse files
author
Mohammad Hasan
authored
Fixed Test error when encoder(image).
When we load an image for the sample, we have to make sure that the image has three color channel (RGB) because it might be grayscale. So we should convert it for sampling.
1 parent 4896cef commit b579373

File tree

1 file changed

+2
-2
lines changed
  • tutorials/03-advanced/image_captioning

1 file changed

+2
-2
lines changed

tutorials/03-advanced/image_captioning/sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1515

1616
def load_image(image_path, transform=None):
17-
image = Image.open(image_path)
17+
image = Image.open(image_path).convert('RGB')
1818
image = image.resize([224, 224], Image.LANCZOS)
1919

2020
if transform is not None:
@@ -78,4 +78,4 @@ def main(args):
7878
parser.add_argument('--hidden_size', type=int , default=512, help='dimension of lstm hidden states')
7979
parser.add_argument('--num_layers', type=int , default=1, help='number of layers in lstm')
8080
args = parser.parse_args()
81-
main(args)
81+
main(args)

0 commit comments

Comments
 (0)