In [12]:
import os
import string

import nltk
import torch
import torch.nn as nn
from nltk.stem import WordNetLemmatizer
from PIL import Image
from torchvision import transforms, models
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Download wordnet data if not already present (quietly)
nltk.download('wordnet', quiet=True)

True

In [2]:
# Function to extract animal mentions from text using the NER model
def extract_animals(text, ner_tokenizer=None, token_classifier=None, animal_label_id=1, max_length=128):
    """Return the words in ``text`` that the token-classification model tags as animals.

    The text is split on whitespace, tokenized as pre-split words and run through
    the NER model. A word counts as an animal mention when the label predicted for
    its *first* sub-token equals ``animal_label_id``.

    Args:
        text: Raw input sentence.
        ner_tokenizer: Hugging Face fast tokenizer (``word_ids`` is required).
            Defaults to the module-level ``tokenizer`` global that
            ``AnimalPipeline.__init__`` sets.
        token_classifier: Token-classification model whose output has ``.logits``.
            Defaults to the module-level ``ner_model`` global.
        animal_label_id: Label id that marks an animal token (``1`` in the original
            code; presumably the model's "animal" tag — confirm against
            ``model.config.id2label``).
        max_length: Maximum number of tokens; words beyond it are truncated away.

    Returns:
        List of matched words exactly as they appear in ``text`` (whitespace-split,
        so attached punctuation such as ``"cat."`` is kept), in text order.

    Raises:
        RuntimeError: if no tokenizer/model was passed and the globals are unset.
    """
    if ner_tokenizer is None:
        ner_tokenizer = globals().get("tokenizer")
    if token_classifier is None:
        token_classifier = globals().get("ner_model")
    if ner_tokenizer is None or token_classifier is None:
        raise RuntimeError(
            "NER tokenizer/model not initialised: create AnimalPipeline first "
            "or pass ner_tokenizer/token_classifier explicitly."
        )

    words = text.split()
    if not words:
        # Nothing to tag; also avoids handing an empty word list to the tokenizer.
        return []

    encoding = ner_tokenizer(
        words,
        is_split_into_words=True,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]

    # Keep the inputs on the model's device (Hugging Face models expose `.device`).
    device = getattr(token_classifier, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

    with torch.no_grad():
        logits = token_classifier(input_ids=input_ids, attention_mask=attention_mask).logits
    # Take batch row 0 explicitly: .squeeze() would turn a length-1 sequence
    # into a 0-d tensor that can no longer be indexed per position.
    predictions = logits.argmax(dim=-1)[0].cpu().tolist()

    # word_ids maps each token position to its source word index
    # (None for special and padding tokens).
    animals = []
    seen_words = set()
    for position, word_id in enumerate(encoding.word_ids(batch_index=0)):
        if word_id is None or word_id in seen_words:
            continue
        # Only the first sub-token of each word decides the word's label.
        seen_words.add(word_id)
        if predictions[position] == animal_label_id:
            # Use the original word rather than re-gluing sub-tokens, which is
            # WordPiece-specific and breaks on [UNK] / lower-casing tokenizers.
            animals.append(words[word_id])
    return animals



In [3]:
# Function to load the image classification model
def load_classifier(model_path="animal_classifier_final.pth"):
    """Rebuild the ResNet-50 animal classifier and load its trained weights.

    The checkpoint is read as a dict containing ``'class_names'`` and
    ``'model_state_dict'`` entries.

    Args:
        model_path: Path to the saved checkpoint.

    Returns:
        ``(model, class_names)``: the model in eval mode, with tensors mapped to
        CPU, plus the list of class names stored in the checkpoint.
    """
    # NOTE(review): unless weights_only=True is in effect, torch.load unpickles
    # arbitrary objects — only load checkpoints from a trusted source.
    state = torch.load(model_path, map_location=torch.device('cpu'))
    labels = state['class_names']

    class AnimalClassifier(nn.Module):
        """ResNet-50 backbone (no pretrained weights) with a 3-layer MLP head."""

        def __init__(self, num_classes):
            super().__init__()
            backbone = models.resnet50(weights=None)
            in_features = backbone.fc.in_features
            # Layer order must not change: the saved state_dict expects the
            # Linear layers at indices fc.0 / fc.3 / fc.6.
            head_layers = [
                nn.Linear(in_features, 512), nn.ReLU(), nn.Dropout(0.3),
                nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.2),
                nn.Linear(256, num_classes),
            ]
            backbone.fc = nn.Sequential(*head_layers)
            self.resnet = backbone

        def forward(self, x):
            return self.resnet(x)

    classifier = AnimalClassifier(num_classes=len(labels))
    classifier.load_state_dict(state['model_state_dict'])
    classifier.eval()
    return classifier, labels



In [4]:
# Function to predict the class of an image using the classifier
def predict(image_path, classifier, class_names, threshold=0.5):
    """Classify a single image and return its class name, or ``'other'`` if unsure.

    Args:
        image_path: Path to an image file readable by PIL.
        classifier: Model that maps a ``(1, 3, 224, 224)`` batch (the shape
            produced by the preprocessing below) to one logit per class.
        class_names: Sequence mapping output index -> class name.
        threshold: Minimum softmax probability of the top class; below it the
            function returns ``'other'``.

    Returns:
        The predicted class name, or ``'other'`` when the top probability is
        below ``threshold``.
    """
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Standard ImageNet channel mean/std; presumably the same normalization
        # used when the classifier was trained — confirm.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # The context manager closes the file handle; convert() returns a fully
    # loaded copy that remains valid after the `with` block.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image_tensor = preprocess(image).unsqueeze(0)

    with torch.no_grad():
        probabilities = torch.softmax(classifier(image_tensor), dim=1)[0]
    # Softmax is monotonic, so the top probability's index is the logit argmax.
    pred_prob, pred_idx = probabilities.max(dim=0)

    if pred_prob.item() < threshold:
        return 'other'
    return class_names[pred_idx.item()]


In [5]:
# Dictionary of animal synonyms
# Maps each image-classifier class name to the words that count as a text
# mention of that animal. AnimalPipeline.is_same_animal looks entries up using
# the predicted class name as the key, so keys presumably need to match the
# classifier's `class_names` exactly — TODO confirm against the checkpoint.
# Synonyms are compared (case-insensitively) against single whitespace-split
# words from the text after WordNet noun lemmatization, so list singular forms
# and write multi-word names hyphenated: a two-word "tree frog" in the text is
# never seen as one word.
animal_synonyms = {
    "owl": [
        "owl", "hoot-owl", "nightbird", "screech-owl", "barn-owl", "horned-owl", 
        "night-raptor", "strix", "bubo", "wise-bird", "hooting-bird"
    ],
    "frog": [
        "frog", "toad", "ranid", "polliwog", "hylid", "tree-frog", "bullfrog", 
        "spring-peeper", "croaker", "leaper", "ribbit"
    ],
    "cat": [
        "cat", "kitten", "feline", "kitty", "pussycat", "puss", "moggy", "tabby", 
        "grimalkin", "tomcat", "gib", "meow-meow", "tigger"
    ],
    "ladybug": [
        "ladybug", "ladybird", "ladybeetle", "coccinellid", "spot-bug", "red-beetle", 
        "luck-bug", "sun-beetle", "polka-dot-bug", "aphid-hunter", "dome-bug"
    ],
    "dove": [
        "dove", "pigeon", "rock-dove", "turtledove", "columbidae", "mourning-dove", 
        "ringdove", "white-dove", "peace-bird", "soft-wing", "love-bird"
    ],
    "bee": [
        "bee", "honeybee", "bumblebee", "worker-bee", "queen-bee", "drone", 
        "apis", "nectar-hunter", "pollen-carrier", "buzz-bug", "hive-dweller"
    ],
    "beaver": [
        "beaver", "castor", "river-beaver", "dam-builder", "flat-tail", "pond-engineer", 
        "woodworker", "gnawer", "wetland-rodent", "bucktooth", "fur-chewer"
    ],
    "flamingo": [
        "flamingo", "flamant", "pink-bird", "wader", "long-legs", "roseate", "scarlet-crane", 
        "tropical-stork", "lagoon-dancer", "curved-beak", "salt-marsh-bird"
    ],
    "elk": [
        "elk", "wapiti", "moose", "stag", "hart", "red-deer", "cervid", 
        "great-deer", "timber-buck", "forest-giant", "antler-king"  # NOTE(review): "forest-giant" is also listed under "bear"
    ],
    "fox": [
        "fox", "vixen", "tod", "red-fox", "silver-fox", "arctic-fox", "furry-trickster", 
        "swift-tail", "shadow-hunter", "sly-one", "bushytail"
    ],
    "squirrel": [
        "squirrel", "tree-rat", "nut-hoarder", "acorn-gatherer", "bushy-tail", "rodent-leaper", 
        "tree-dweller", "chatterbox", "grey-squirrel", "red-squirrel", "ground-squirrel"
    ],
    "goose": [
        "goose", "gander", "gaggle", "wild-goose", "snow-goose", "greylag", "honker", 
        "barnyard-bird", "long-neck", "waterfowl", "migrator"
    ],
    "bear": [
        "bear", "bruin", "ursus", "grizzly", "black-bear", "polar-bear", "cave-dweller", 
        "honey-lover", "big-paw", "forest-giant", "fur-giant"  # NOTE(review): "forest-giant" is also listed under "elk" — ambiguous
    ]
}


In [6]:
class AnimalPipeline:
    """Check whether an image and a sentence refer to the same animal.

    Combines a token-classification (NER) model that finds animal words in text
    with an image classifier, and matches the two via ``animal_synonyms``.
    """

    def __init__(self, ner_model_path="animal_ner_model", classifier_model_path="animal_classifier_final.pth"):
        """Load the NER tokenizer/model and the image classifier.

        Args:
            ner_model_path: Path or identifier passed to ``from_pretrained`` for
                both the NER tokenizer and the NER model.
            classifier_model_path: Checkpoint file passed to ``load_classifier``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(ner_model_path)
        self.ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_path)
        self.ner_model.eval()
        # extract_animals reads the module-level `tokenizer` / `ner_model` globals.
        self._publish_ner_globals()

        self.classifier, self.class_names = load_classifier(classifier_model_path)

        self.animal_synonyms = animal_synonyms
        self.lemmatizer = WordNetLemmatizer()

    def _publish_ner_globals(self):
        """Point the `tokenizer` / `ner_model` globals at this instance's NER model."""
        global tokenizer, ner_model
        tokenizer = self.tokenizer
        ner_model = self.ner_model

    def extract_animals_from_text(self, text):
        """Extract animal names from text using NER.

        Each tagged word is lower-cased, stripped of surrounding punctuation
        (e.g. ``"Cat."`` -> ``"cat"``) and lemmatized with WordNet as a noun.
        Tagged tokens that consist only of punctuation are dropped.
        """
        # Re-publish on every call: a pipeline constructed later would otherwise
        # have replaced the globals with its own model, and this instance would
        # silently use the wrong one.
        self._publish_ner_globals()
        extracted_animals = extract_animals(text)

        lemmatized_animals = []
        for animal in extracted_animals:
            # text.split() keeps punctuation attached ("cat."), which could never
            # match a synonym, so trim it before lemmatizing.
            word = animal.lower().strip(string.punctuation)
            if word:
                lemmatized_animals.append(self.lemmatizer.lemmatize(word, pos='n'))
        return lemmatized_animals

    def predict_from_image(self, image_path, threshold=0.5):
        """Predict the animal in the image using the classifier.

        Returns a name from ``self.class_names``, or ``'other'`` when the top
        class probability is below ``threshold``.
        """
        return predict(image_path, self.classifier, self.class_names, threshold)

    def is_same_animal(self, image_path, text, threshold=0.5):
        """Return True if the animal predicted for the image is mentioned in ``text``.

        A mention counts when any extracted word equals (case-insensitively) the
        predicted class name itself or one of its entries in ``animal_synonyms``.
        Returns False when the classifier answers ``'other'`` or nothing matches.
        """
        predicted_animal = self.predict_from_image(image_path, threshold)
        # Unrecognised image: skip running the NER model entirely.
        if predicted_animal == 'other':
            return False

        key = str(predicted_animal).lower()
        # Look the class up case-insensitively and always accept the class name
        # itself, so a class missing from (or capitalised differently in) the
        # synonyms dictionary can still be matched.
        synonyms = self.animal_synonyms.get(predicted_animal)
        if synonyms is None:
            synonyms = self.animal_synonyms.get(key, [])
        accepted = {key, *(syn.lower() for syn in synonyms)}

        extracted_animals = self.extract_animals_from_text(text)
        return any(animal.lower() in accepted for animal in extracted_animals)


In [11]:
# Demo: does the caption mention the animal shown in the image?
# NOTE(review): the caption appears to be missing its subject ("The ___ is
# jumping..."), so the printed False may just reflect that typo — confirm intent.
IMAGE_PATH = "Animals_images_testing/cat.jpg"
CAPTION = "The is jumping on the lily."

pipeline = AnimalPipeline()
result = pipeline.is_same_animal(IMAGE_PATH, CAPTION)
print(result)

False
