In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
from tqdm import tqdm
import pandas as pd


In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

#Load pre-trained model and tokenizer
model_name = "dslim/bert-base-NER"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)

# Check the original label mappings from the pre-trained model
original_id2label = model.config.id2label
original_label2id = model.config.label2id
#print("Original labels:", original_id2label)


In [8]:
def tokenize_and_align_labels(text, word_labels, tokenizer, label_map, max_len=128, label_all_tokens=False):
    # Split the text into words
    words = text.split()
    encoding = tokenizer(
        words,
        is_split_into_words=True,
        return_offsets_mapping=True,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    
    word_ids = encoding.word_ids(batch_index=0)
    labels = []
    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None:
            labels.append(-100)
        else:
            if word_idx != previous_word_idx:
                labels.append(label_map.get(word_labels[word_idx], label_map["O"]))
            else:

                if label_all_tokens:
                    label = label_map.get(word_labels[word_idx], label_map["O"])
                    if label == label_map.get("B-ANIMAL"):
                        label = label_map.get("I-ANIMAL")
                    labels.append(label)
                else:
                    labels.append(-100)
            previous_word_idx = word_idx
    encoding["labels"] = torch.tensor(labels)
    return encoding
    

class AnimalNERDataset(Dataset):
    def __init__(self, texts, tags, tokenizer, max_len=128):
        self.texts = texts
        self.tags = tags 
        self.tokenizer = tokenizer
        self.max_len = max_len
        
        self.label_map = {
            "O": original_label2id["O"],          # Outside any entity
            "B-ANIMAL": original_label2id["B-MISC"],# Beginning of animal entity
            "I-ANIMAL": original_label2id["I-MISC"] # Inside of animal entity
        }
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        word_tags = self.tags[idx]  

        encoding = tokenize_and_align_labels(
            text, 
            word_tags, 
            self.tokenizer, 
            self.label_map, 
            max_len=self.max_len, 
            label_all_tokens=True
        )
    
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        
        return encoding

def create_sample_data():
    df = pd.read_csv("/Users/anastasiiaserputko/Test/Task2/animal_sentences.csv")
        
    texts = df['sentence'].tolist()

    animal_list = [
    "cat", "feline", "kitty", "kitten", "tomcat", "puss",
    "bear", "grizzly", "bruin", "cub", "ursine",
    "goose", "gander", "gosling", "waterfowl",
    "squirrel", "chipmunk", "tree-dweller",
    "fox", "vixen", "reynard", "tod",
    "elk", "moose", "wapiti", "stag",
    "flamingo", "wader", "pinkbird",
    "owl", "hooter",
    "frog", "toad", "amphibian", "croaker",
    "beaver", "dam-builder",
    "bee", "honeybee", "bumblebee", "drone",
    "dove", "pigeon", "columbidae", "columbid",
    "ladybug", "ladybird", "beetle", "coccinellid"
    ]
        
    tags = []
    for text in texts:
        words = text.split()
        word_tags = []
        for word in words:
            lower_word = word.lower().strip(".,!?")
            if lower_word in animal_list:
                word_tags.append("B-ANIMAL")
            else:
                word_tags.append("O")
        tags.append(word_tags)
    
    
    split_index = int(0.8 * len(texts))  
    train_texts = texts[:split_index]
    train_tags = tags[:split_index]
    val_texts = texts[split_index:]
    val_tags = tags[split_index:]
    
    return train_texts, train_tags, val_texts, val_tags

def train_model(model, train_dataloader, val_dataloader, epochs=2):
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    
    # Calculate total training steps
    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=0,
        num_training_steps=total_steps
    )
    
    model.to(device)
    
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        
        # Training loop
        model.train()
        train_loss = 0
        for batch in tqdm(train_dataloader, desc="Training"):
            optimizer.zero_grad()
            
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            loss = outputs.loss
            train_loss += loss.item()
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
        
        avg_train_loss = train_loss / len(train_dataloader)
        print(f"Average training loss: {avg_train_loss}")
        
        # Validation loop
        model.eval()
        val_loss = 0
        predictions = []
        true_labels = []
        
        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc="Validation"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)
                
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                
                loss = outputs.loss
                val_loss += loss.item()

                logits = outputs.logits
                pred = torch.argmax(logits, dim=2)
                
                active_mask = labels != -100
                

                predictions.extend(pred[active_mask].cpu().numpy())
                true_labels.extend(labels[active_mask].cpu().numpy())
        
        avg_val_loss = val_loss / len(val_dataloader)
        print(f"Validation loss: {avg_val_loss}")
        
        precision, recall, f1, _ = precision_recall_fscore_support(
            true_labels, predictions, average="weighted"
        )
        accuracy = accuracy_score(true_labels, predictions)
        
        print(f"Validation Accuracy: {accuracy:.4f}")
        print(f"Validation Precision: {precision:.4f}")
        print(f"Validation Recall: {recall:.4f}")
        print(f"Validation F1: {f1:.4f}")
    
    return model

def extract_animals(text, model, tokenizer):
    model.eval()
    words = text.split()
    encoding = tokenizer(
        words,
        is_split_into_words=True,
        return_offsets_mapping=True,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt"
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)
    word_ids = encoding.word_ids(batch_index=0)  # отримуємо індекси слів для кожного токена
    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze().tolist())
    
    offset_mapping = encoding.get("offset_mapping", None)
    if offset_mapping is not None:
        offset_mapping = offset_mapping.squeeze().tolist()
    else:
        offset_mapping = [(None, None)] * len(tokens)  # Заглушка, щоб уникнути помилок

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = torch.argmax(outputs.logits, dim=2).squeeze().cpu().numpy()
        
    #print(f"\nText: {text}")
    #print(f"Tokens: {tokens}")
    #print(f"Word IDs: {word_ids}")
    #print(f"Offset Mapping: {offset_mapping}")
    #print(f"Predictions: {predictions}")

    #print("\nToken Predictions:")
    for token, word_id, offset, pred in zip(tokens, word_ids, offset_mapping, predictions):
        label_name = original_id2label.get(pred, "O")
        #print(f"Token: {token.ljust(10)} | Word ID: {str(word_id).ljust(3)} | Offset: {offset} | Prediction: {label_name}")

    word_predictions = {}
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            continue
        token = tokens[idx]

        token_clean = token.replace("##", "")
        if word_id not in word_predictions:
            word_predictions[word_id] = {"text": token_clean, "predictions": [predictions[idx]]}
        else:
            word_predictions[word_id]["text"] += token_clean
            word_predictions[word_id]["predictions"].append(predictions[idx])
    
    results = []
    for word_id, info in word_predictions.items():
        if info["predictions"][0] == original_label2id["B-MISC"]:
            results.append(info["text"])
            
    return results

# Main function to run the entire pipeline
def main():
    # Create sample data
    train_texts, train_tags, val_texts, val_tags = create_sample_data()
    
    # Create datasets and dataloaders
    train_dataset = AnimalNERDataset(train_texts, train_tags, tokenizer)
    val_dataset = AnimalNERDataset(val_texts, val_tags, tokenizer)
    
    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=2)
    
    # Fine-tune the model
    fine_tuned_model = train_model(model, train_dataloader, val_dataloader, epochs=2)
    
    # Save the fine-tuned model
    fine_tuned_model.save_pretrained("animal_ner_model")
    tokenizer.save_pretrained("animal_ner_model")
    
    # Test on sample sentences
    test_sentences = [
        "There is a lion in the picture.",
        "The image shows a beautiful dolphin jumping.",
        "I can see both a giraffe and a zebra in this photo.",
        "This appears to be a picture of a small rabbit."
    ]
    
    for sentence in test_sentences:
        animals = extract_animals(sentence, fine_tuned_model, tokenizer)
        print(f"Text: {sentence}")
        print(f"Detected animals: {animals}")
        print()


main()

Epoch 1/2


Training: 100%|█████████████████████████████| 2600/2600 [16:07<00:00,  2.69it/s]


Average training loss: 0.00042790482715964536


Validation: 100%|█████████████████████████████| 650/650 [00:43<00:00, 14.83it/s]


Validation loss: 0.11852206967749217
Validation Accuracy: 0.9907
Validation Precision: 0.9916
Validation Recall: 0.9907
Validation F1: 0.9909
Epoch 2/2


Training: 100%|█████████████████████████████| 2600/2600 [15:40<00:00,  2.77it/s]


Average training loss: 1.1891512759868524e-06


Validation: 100%|█████████████████████████████| 650/650 [00:42<00:00, 15.27it/s]


Validation loss: 0.12479027056866439
Validation Accuracy: 0.9907
Validation Precision: 0.9916
Validation Recall: 0.9907
Validation F1: 0.9909
Text: There is a lion in the picture.
Detected animals: ['lion']

Text: The image shows a beautiful dolphin jumping.
Detected animals: ['dolphin']

Text: I can see both a giraffe and a zebra in this photo.
Detected animals: ['giraffe', 'zebra']

Text: This appears to be a picture of a small rabbit.
Detected animals: ['rabbit.']



In [13]:
def test_model_on_examples(model, tokenizer):
    test_sentences = [
    "The cat is sitting on the windowsill.",
    "The feline stretched lazily in the sun.",
    "A kitten was playing with a ball of yarn.",
    
    "The bear caught a fish from the river.",
    "A large grizzly wandered through the forest.",
    "The cub followed its mother closely.",
    
    "A goose honked loudly near the pond.",
    "The gander led its flock across the field.",
    "A gosling swam behind its mother.",
    
    "A squirrel is collecting nuts for winter.",
    "The chipmunk darted into its burrow.",
    "A ground squirrel peeked out from behind a tree.",
    
    "The fox ran swiftly through the meadow.",
    "A vixen and her cubs played in the field.",
    "The red fox is known for its cunning nature.",
    
    "An elk stood majestically in the clearing.",
    "The stag had an impressive set of antlers.",
    "A wapiti grazed peacefully in the meadow.",
    
    "A flamingo stood gracefully on one leg.",
    "The wader dipped its beak into the water.",
    "A bright pinkbird preened its feathers.",
    
    "An owl hooted softly in the night.",
    "The horned owl watched from its perch.",
    "A barn owl soared silently over the field.",
    
    "The frog jumped into the pond.",
    "A toad sat on a lily pad.",
    "The tree frog clung to the branch.",
    
    "A beaver built a dam in the stream.",
    "The dam-builder gnawed on a piece of wood.",
    "A muskrat swam near the shore.",
    
    "A bee buzzed around the flowers.",
    "The bumblebee collected nectar from a blossom.",
    "A honeybee returned to the hive with pollen.",
    
    "A dove perched on the rooftop.",
    "The pigeon cooed softly in the park.",
    "A squab was learning to fly.",
    
    "A ladybug landed on my hand.",
    "The ladybird crawled on a leaf.",
    "A little beetle scurried across the petal."

    "A tree-dweller landed on my hand.",
    "A little spotted bug across the petal."
    ]
    
    print("\n Testing model on various sentences...\n")
    
    for sentence in test_sentences:
        detected_animals = extract_animals(sentence, model, tokenizer)
        print(f"Text: {sentence}")
        print(f"Detected Animals: {detected_animals}")
        print("="*50)
        
test_model_on_examples(model, tokenizer)


 Testing model on various sentences...

Text: The cat is sitting on the windowsill.
Detected Animals: ['cat']
Text: The feline stretched lazily in the sun.
Detected Animals: ['feline']
Text: A kitten was playing with a ball of yarn.
Detected Animals: ['kitten', 'yarn.']
Text: The bear caught a fish from the river.
Detected Animals: ['bear']
Text: A large grizzly wandered through the forest.
Detected Animals: ['grizzly']
Text: The cub followed its mother closely.
Detected Animals: ['cub']
Text: A goose honked loudly near the pond.
Detected Animals: ['goose']
Text: The gander led its flock across the field.
Detected Animals: ['gander']
Text: A gosling swam behind its mother.
Detected Animals: ['gosling']
Text: A squirrel is collecting nuts for winter.
Detected Animals: ['squirrel']
Text: The chipmunk darted into its burrow.
Detected Animals: ['chipmunk']
Text: A ground squirrel peeked out from behind a tree.
Detected Animals: ['squirrel']
Text: The fox ran swiftly through the meadow.
De