In [None]:
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import torch
from PIL import Image

class ComplexImageCaptioningModel:
    """Generate text captions for images with a pretrained checkpoint.

    The constructor loads a ``VisionEncoderDecoderModel``, a
    ``ViTFeatureExtractor`` and a tokenizer from the same checkpoint name and
    moves the model to CUDA when ``torch.cuda.is_available()``, else to CPU.
    """

    def __init__(self, model_name="nlpconnect/vit-gpt2-image-captioning",
                 max_length=16, num_beams=4):
        """Load the model, feature extractor and tokenizer.

        Args:
            model_name: Checkpoint identifier handed to every
                ``from_pretrained`` call.
            max_length: Forwarded to ``model.generate`` as ``max_length``.
            num_beams: Forwarded to ``model.generate`` as ``num_beams``.
        """
        self.model = VisionEncoderDecoderModel.from_pretrained(model_name)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        self.max_length = max_length
        self.num_beams = num_beams
        # Snapshot of the generation settings; generate_captions() reads this
        # dict, not the two attributes above.
        self.gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}

    def preprocess_images(self, image_paths):
        """Open every path and return the images converted to RGB.

        Args:
            image_paths: Iterable of values accepted by ``PIL.Image.open``.

        Returns:
            A list with one RGB image per input path, in input order.

        Raises:
            Whatever ``Image.open`` / ``convert`` raise for a missing or
            unreadable file; nothing is caught here.
        """
        images = []
        for image_path in image_paths:
            # Image.open is lazy and keeps the file handle open. Converting
            # inside the ``with`` block forces the pixel data to be read and
            # yields an independent copy, so the handle is released right away.
            with Image.open(image_path) as source_image:
                images.append(source_image.convert(mode="RGB"))
        return images

    def generate_captions(self, images):
        """Return one stripped caption string per image.

        Args:
            images: Images accepted by the feature extractor (as produced by
                ``preprocess_images``).

        Returns:
            A list of caption strings; ``[]`` when ``images`` is an empty
            list or tuple.
        """
        # Nothing to caption: skip the feature extractor and the model rather
        # than letting an empty batch fail inside them.
        if isinstance(images, (list, tuple)) and len(images) == 0:
            return []

        features = self.feature_extractor(images=images, return_tensors="pt")
        # The inputs must live on the same device as the model.
        pixel_values = features.pixel_values.to(self.device)

        output_ids = self.model.generate(pixel_values, **self.gen_kwargs)

        decoded = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return [caption.strip() for caption in decoded]

    def predict_step(self, image_paths):
        """Caption the images found at ``image_paths``.

        This is a best-effort entry point: any exception raised while loading
        or captioning is printed and an empty list is returned instead.

        Args:
            image_paths: Iterable of image file paths.

        Returns:
            A list of caption strings, or ``[]`` on error or empty input.
        """
        try:
            images = self.preprocess_images(image_paths)
            return self.generate_captions(images)
        except Exception as e:
            # Deliberate catch-all: callers get [] rather than a traceback.
            print(f"Error during prediction: {e}")
            return []

# Example usage: build the captioner (this loads the pretrained checkpoint),
# caption a single image file, and print the resulting list of captions.
sample_image_paths = ["/content/images-1.jpeg"]
model_instance = ComplexImageCaptioningModel()
result = model_instance.predict_step(sample_image_paths)
print(result)




['a man standing in the middle of a field']
