In [None]:
import yaml
import torch
import timm
import requests
import email
import imaplib
import time
import re
import io
from datetime import datetime
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from megadetector.detection import run_detector
from megadetector.visualization import visualization_utils as vis_utils

# Load settings from configuration file
with open('../config.yaml', encoding='utf-8') as file:
    config = yaml.safe_load(file)

IMAP_HOST = config['imap_config']['host']
EMAIL_USER = config['imap_config']['user']
EMAIL_PASS = config['imap_config']['password']
TELEGRAM_BOT_TOKEN = config['telegram_config']['bot_token']
TELEGRAM_CHAT_ID = '-1002249589791' # replace with config after tests

# Detection and Classification Model Settings
MODEL_PATH_DETECTOR = '../models/md_v5a.0.0.pt'
MODEL_PATH_CLASSIFIER = '../models/deepfaune-vit_large_patch14_dinov2.lvd142m.pt'
BACKBONE = 'vit_large_patch14_dinov2'
ANIMAL_CLASSES = ["badger", "ibex", "red deer", "chamois", "cat", "goat", "roe deer", "dog", "squirrel", "equid", "genet", "hedgehog", "lagomorph", "wolf", "lynx", "marmot", "micromammal", "mouflon", "sheep", "mustelid", "bird", "bear", "nutria", "fox", "wild boar", "cow"]
DETECTOR_CLASSES = ["animal", "human", "vehicle"]
species_of_interest = {"wild boar", "bear", "wolf", "roe deer", "red deer"}
DETECTION_THRESHOLD = 0.05
CLASSIFICATION_THRESHOLD = 0.05

# Load the MegaDetector model.
# It is bound to `model`, not `detector`: the `detector()` function defined
# further down in this file reads the global `model`, and its `def` statement
# rebinds the name `detector`, which would discard a model stored under it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = run_detector.load_detector(MODEL_PATH_DETECTOR)

# move to utils.py

class Classifier:
    """Image classifier for animal species.

    Builds the ``BACKBONE`` architecture through timm with one output per
    entry of ``ANIMAL_CLASSES``, loads the weights stored at
    ``MODEL_PATH_CLASSIFIER`` and runs inference on the module-level ``device``.
    """

    def __init__(self):
        self.model = timm.create_model(BACKBONE, pretrained=False, num_classes=len(ANIMAL_CLASSES))
        # NOTE(review): torch.load unpickles the checkpoint file; only load
        # checkpoints that come from a trusted source.
        state_dict = torch.load(MODEL_PATH_CLASSIFIER, map_location=torch.device(device))['state_dict']
        # Checkpoint keys are stripped of 'base_model.' so they line up with
        # the parameter names of the bare timm model.
        self.model.load_state_dict({k.replace('base_model.', ''): v for k, v in state_dict.items()})
        self.transforms = transforms.Compose([
            transforms.Resize((518, 518), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            # Standard ImageNet mean/std values, one per RGB channel.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # Run the network on the device the weights were mapped to; without
        # this the model stays on the CPU even when CUDA is available.
        self.model.to(device)
        self.model.eval()

    def predict(self, image):
        """Predict the species of an animal in the image.

        Args:
            image: The (cropped) image to classify, normally a PIL image.

        Returns:
            A ``(species, probability)`` tuple holding the ``ANIMAL_CLASSES``
            entry with the highest softmax probability and that probability
            as a float.
        """
        # Normalize() above is configured for exactly three channels, so
        # grayscale, palette or RGBA inputs must be converted first.
        if isinstance(image, Image.Image) and image.mode != 'RGB':
            image = image.convert('RGB')
        # The input batch has to live on the same device as the model.
        img_tensor = self.transforms(image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = self.model(img_tensor)
            probabilities = torch.nn.functional.softmax(output, dim=1)
            top_p, top_class = probabilities.topk(1, dim=1)
            return ANIMAL_CLASSES[top_class.item()], top_p.item()
        
def image_to_bytes(image):
    """Encode an image as JPEG into an in-memory buffer.

    Args:
        image: A PIL image.

    Returns:
        An ``io.BytesIO`` positioned at offset 0 that contains the JPEG data.
    """
    # Pillow's JPEG writer only accepts these modes; anything else (for
    # example RGBA or palette images from PNG attachments) makes save() raise
    # OSError, so such images are converted to RGB first.
    if image.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
        image = image.convert('RGB')
    byte_arr = io.BytesIO()
    image.save(byte_arr, format='JPEG')
    byte_arr.seek(0)
    return byte_arr

def detector(image):
    """Detect objects in an image with MegaDetector.

    The image is re-encoded through ``image_to_bytes``, loaded with
    ``vis_utils.load_image`` and handed to
    ``model.generate_detections_one_image``. Only detections whose ``conf`` is
    strictly greater than ``DETECTION_THRESHOLD`` are returned, as a list.
    """
    # NOTE(review): `model` is a global this function does not define; it has
    # to be bound at module level to the loaded detector before the first call.
    jpeg_buffer = image_to_bytes(image)
    loaded_image = vis_utils.load_image(jpeg_buffer)
    raw_result = model.generate_detections_one_image(loaded_image)

    kept = []
    for candidate in raw_result['detections']:
        if candidate['conf'] > DETECTION_THRESHOLD:
            kept.append(candidate)
    return kept



# Initialize classifier: a single module-level instance, created when this
# module is executed and passed to annotate_image() by process_single_image().
classifier = Classifier()

def load_font(font_path, font_size):
    """Return the TrueType font at ``font_path`` in size ``font_size``.

    Falls back to PIL's built-in default font when the font file cannot be
    opened.
    """
    try:
        font = ImageFont.truetype(font_path, font_size)
    except OSError:  # IOError is an alias of OSError in Python 3
        font = ImageFont.load_default()
    return font

def annotate_image(image, detections, classifier, species_of_interest, font):
    """Classify every detection and draw the accepted ones onto ``image``.

    ``image`` is modified in place: each detection whose classification
    confidence exceeds ``CLASSIFICATION_THRESHOLD`` gets a red bounding box,
    and a text label as well when its species is in ``species_of_interest``.

    Args:
        image: PIL image to crop from and draw on.
        detections: Iterable of dicts with a ``'bbox'`` entry of the form
            ``(left, top, width, height)``, given as fractions of the image
            size (they are multiplied by the image width/height below).
        classifier: Object with a ``predict(image)`` method returning
            ``(species, confidence)``.
        species_of_interest: Collection of species names that get a label.
        font: Font used for the labels.

    Returns:
        A tuple ``(detection_results, species_counts, highest_confidences)``:
        a list of ``(species, confidence, bbox)`` for accepted detections, and
        two dicts keyed by every entry of ``ANIMAL_CLASSES`` holding the number
        of accepted detections and the best confidence per species.
    """
    draw = ImageDraw.Draw(image)
    detection_results = []
    species_counts = {species: 0 for species in ANIMAL_CLASSES}
    highest_confidences = {species: 0 for species in ANIMAL_CLASSES}

    # NOTE(review): every detection is run through the species classifier,
    # whatever its detector category; DETECTOR_CLASSES also lists "human" and
    # "vehicle" - confirm whether those detections should be skipped here.
    for detection in detections:
        bbox = detection['bbox']
        left, top, width, height = bbox
        left_resized = int(left * image.width)
        top_resized = int(top * image.height)
        right_resized = int((left + width) * image.width)
        bottom_resized = int((top + height) * image.height)

        # Very small boxes can collapse to zero width/height after integer
        # truncation; an empty crop cannot be resized for the classifier.
        if right_resized <= left_resized or bottom_resized <= top_resized:
            continue

        cropped_image = image.crop((left_resized, top_resized, right_resized, bottom_resized))
        species, confidence = classifier.predict(cropped_image)
        if confidence > CLASSIFICATION_THRESHOLD:
            detection_results.append((species, confidence, bbox))
            species_counts[species] += 1
            highest_confidences[species] = max(highest_confidences[species], confidence)

            # Draw bounding box and label
            draw.rectangle([left_resized, top_resized, right_resized, bottom_resized], outline="red", width=4)
            if species in species_of_interest:
                label = f"{species.title()}: {int(confidence * 100)}%"
                # Paint a filled background first so the white text stays readable.
                text_bbox = draw.textbbox((left_resized, top_resized), label, font=font)
                draw.rectangle(text_bbox, fill="red")
                draw.text((left_resized, top_resized), label, fill="white", font=font)

    return detection_results, species_counts, highest_confidences

def generate_caption(species_of_interest_detected, species_counts, highest_confidences, capture_time):
    """Generate a caption for the image.

    Args:
        species_of_interest_detected: Dict mapping each detected species of
            interest to its detection count; may be empty.
        species_counts: Dict mapping species name to detection count.
        highest_confidences: Dict mapping species name to the best
            classification confidence (0..1) seen for it.
        capture_time: Text inserted after ``"Time: "``.

    Returns:
        A tuple ``(caption, highest_confidence_species)``. The second element
        is the headline species, or ``None`` when no species of interest was
        detected.
    """
    if species_of_interest_detected:
        # Headline the species of interest with the highest classification
        # confidence. Ties are broken by detection count and then by name so
        # the result does not depend on dict/set iteration order.
        highest_confidence_species = max(
            species_of_interest_detected,
            key=lambda species: (
                highest_confidences.get(species, 0),
                species_of_interest_detected[species],
                species,
            ),
        )
        count_highest_confidence_species = species_of_interest_detected[highest_confidence_species]
        headline = f"{count_highest_confidence_species} {highest_confidence_species.upper()} DETECTED\n"
    else:
        highest_confidence_species = None
        headline = "NO SPECIES OF INTEREST DETECTED\n"

    lines = [
        headline,
        "Location: TBC\n",
        f"Time: {capture_time}\n",
        "------------------------\n",
        "Further information:\n",
    ]

    # List the two most frequently detected species (of interest or not).
    detected = [
        (species, count, highest_confidences[species])
        for species, count in species_counts.items()
        if count > 0
    ]
    detected.sort(key=lambda item: item[1], reverse=True)
    for species, count, max_conf in detected[:2]:
        lines.append(f"{species.title()}: {count} detections, highest confidence {int(max_conf * 100)}%\n")

    return "".join(lines), highest_confidence_species

def save_image_with_metadata(image, highest_confidence_species):
    """Save the image under ``../data`` with a descriptive file name.

    No metadata is embedded in the file itself: the current local timestamp
    and the species name (spaces replaced by underscores, or
    ``no_species_of_interest`` when ``highest_confidence_species`` is falsy)
    are encoded in the file name only.

    Args:
        image: PIL image to save.
        highest_confidence_species: Species name used in the file name, or
            ``None``.

    Returns:
        The path the image was written to.
    """
    import os  # local import: this file does not import os at module level

    timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    species_name = highest_confidence_species.replace(' ', '_') if highest_confidence_species else "no_species_of_interest"
    file_path = f"../data/{timestamp}-{species_name}.jpg"

    # A missing output directory would otherwise make save() fail.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    # The .jpg extension selects Pillow's JPEG writer, which rejects modes
    # such as RGBA or P.
    if image.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
        image = image.convert('RGB')
    image.save(file_path)
    return file_path

def _read_capture_time(image):
    """Return the image's EXIF capture time as 'YYYY-MM-DD HH:MM:SS', or the current time."""
    try:
        exif_data = image._getexif()
        capture_time_exif = exif_data.get(36867)  # EXIF tag 36867 = DateTimeOriginal
        return datetime.strptime(capture_time_exif, "%Y:%m:%d %H:%M:%S").strftime("%Y-%m-%d %H:%M:%S")
    except (AttributeError, TypeError, ValueError):
        # No _getexif() on this image type, no EXIF block, no timestamp tag,
        # or a timestamp that does not match the expected format.
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def process_single_image(image):
    """Process a single image for detection and classification.

    Runs the detector on the image; if anything is detected, the detections
    are classified and drawn onto a copy, a caption is generated and an
    unannotated copy is saved to disk.

    Args:
        image: PIL image to process. It is not modified.

    Returns:
        ``(annotated_image, caption)`` when there is at least one detection,
        otherwise ``(None, None)``.
    """
    # Work on an RGB version: JPEG encoding rejects modes such as RGBA or P,
    # and the classifier normalises exactly three channels. The original
    # `image` is kept untouched because it is the object carrying the EXIF data.
    rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
    saved_image = rgb_image.copy()
    annotated_image = rgb_image.copy()
    font = load_font("Carlito-Bold.ttf", 100)

    # Detect objects in the image
    detections = detector(rgb_image)
    print(f"{len(detections)} detections.")

    if not detections:
        return None, None

    # Annotate the image (the per-detection result list is not needed here)
    _, species_counts, highest_confidences = annotate_image(annotated_image, detections, classifier, species_of_interest, font)

    # Generate caption
    species_of_interest_detected = {species: species_counts[species] for species in species_of_interest if species_counts[species] > 0}
    capture_time = _read_capture_time(image)
    caption, highest_confidence_species = generate_caption(species_of_interest_detected, species_counts, highest_confidences, capture_time)

    # Save the unannotated image
    save_image_with_metadata(saved_image, highest_confidence_species)

    return annotated_image, caption

def send_photo_to_telegram(bot_token, chat_id, photo, caption, timeout=30):
    """Send photo with caption to Telegram.

    Args:
        bot_token: Telegram bot token, inserted into the API URL.
        chat_id: Identifier of the chat that receives the photo.
        photo: PIL image; it is uploaded as JPEG.
        caption: Caption text sent with the photo.
        timeout: Seconds to wait for the Telegram API before giving up.

    Raises:
        requests.RequestException: If the request fails, times out or the API
            answers with an HTTP error status.
    """
    url = f"https://api.telegram.org/bot{bot_token}/sendPhoto"
    # Pillow's JPEG writer rejects modes such as RGBA or P.
    if photo.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
        photo = photo.convert('RGB')
    with io.BytesIO() as buf:
        photo.save(buf, format='JPEG')
        buf.seek(0)
        files = {'photo': buf}
        params = {'chat_id': chat_id, 'caption': caption}
        # Without a timeout a stalled connection would block the caller forever.
        response = requests.post(url, files=files, data=params, timeout=timeout)
        response.raise_for_status()
        print("Alert sent.")

def download_image_from_url(url, timeout=30):
    """Download an image from a URL.

    Args:
        url: Address of the image.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded PIL image, or ``None`` when the download fails or the
        response body is not a readable image.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
        # Decode now so that truncated or corrupt data is reported here
        # rather than at some later use of the image.
        image.load()
        return image
    except requests.RequestException as e:
        # Must precede the OSError handler: RequestException derives from IOError.
        print(f"Error downloading image from {url}: {str(e)}")
        return None
    except OSError as e:
        # Raised by Pillow when the payload is not a valid image.
        print(f"Error decoding image from {url}: {e}")
        return None

def extract_images_from_email(msg):
    """Extract images from an email message.

    Two sources are considered for every part of the message:

    * ``image/*`` parts whose Content-Disposition contains ``attachment``;
    * ``http(s)`` URLs in the ``src`` attribute of ``<img>`` tags found in
      ``text/html`` parts, which are downloaded.

    Parts that cannot be decoded as images are reported and skipped.

    Args:
        msg: An ``email.message.Message``.

    Returns:
        A list of PIL images (possibly empty).
    """
    from html import unescape  # local import: html is not imported at file level

    # Accepts single or double quotes, other attributes before src and any
    # letter case; `\s` before src keeps attributes like data-src from matching.
    img_src_pattern = re.compile(
        r"""<img\b[^>]*?\ssrc\s*=\s*(?:"(https?://[^"]+)"|'(https?://[^']+)')""",
        re.IGNORECASE,
    )

    image_list = []
    # walk() also yields the message itself when it is not multipart, so
    # single-part HTML mails are covered as well.
    for part in msg.walk():
        content_type = part.get_content_type()
        content_disposition = str(part.get('Content-Disposition', '')).lower()

        if content_type.startswith('image/') and 'attachment' in content_disposition:
            image_data = part.get_payload(decode=True)
            try:
                image_list.append(Image.open(io.BytesIO(image_data)))
            except (OSError, TypeError) as e:
                # Corrupt or empty attachment: skip it, keep the other images.
                print(f"Error reading image attachment: {e}")
        elif content_type == 'text/html':
            payload = part.get_payload(decode=True)
            if payload is None:
                continue
            # Honour the charset declared by the part instead of assuming UTF-8.
            charset = part.get_content_charset() or 'utf-8'
            try:
                html_body = payload.decode(charset, errors='replace')
            except LookupError:
                # Unknown charset name in the header.
                html_body = payload.decode('utf-8', errors='replace')
            for match in img_src_pattern.finditer(html_body):
                # Attribute values are HTML-escaped (e.g. '&amp;' in query strings).
                url = unescape(match.group(1) or match.group(2))
                image = download_image_from_url(url)
                if image is not None:
                    image_list.append(image)
    return image_list

def check_emails():
    """Check emails for new messages with images.

    Connects to ``IMAP_HOST``, searches the inbox for UNSEEN messages,
    extracts the images of each one, processes them and sends every image
    that produced detections to Telegram. The IMAP session is always logged
    out, including when processing raises.
    """
    mail = imaplib.IMAP4_SSL(IMAP_HOST)
    try:
        mail.login(EMAIL_USER, EMAIL_PASS)
        mail.select('inbox')
        typ, data = mail.search(None, 'UNSEEN')
        if typ != 'OK' or not data or not data[0]:
            return
        for num in data[0].split():
            typ, msg_data = mail.fetch(num, '(RFC822)')
            # Skip responses that do not carry a message body tuple.
            if typ != 'OK' or not msg_data or not isinstance(msg_data[0], tuple):
                continue
            msg = email.message_from_bytes(msg_data[0][1])
            for image in extract_images_from_email(msg):
                processed_image, caption = process_single_image(image)
                if processed_image is not None:
                    send_photo_to_telegram(TELEGRAM_BOT_TOKEN, TELEGRAM_CHAT_ID, processed_image, caption)
    finally:
        try:
            mail.logout()
        except (imaplib.IMAP4.error, OSError):
            # The connection is already unusable; do not mask the original error.
            pass

def _monitor_inbox():
    """Poll the mailbox once per second until interrupted from the keyboard.

    Any other exception raised while checking is printed and polling resumes.
    """
    print(f"Monitoring {EMAIL_USER} for new messages...")
    keep_running = True
    while keep_running:
        try:
            time.sleep(1)
            check_emails()
        except KeyboardInterrupt:
            print("Interrupted by user")
            keep_running = False
        except Exception as err:
            print(f"An error occurred: {err}")
            print(f"\nMonitoring {EMAIL_USER} for new messages...")


if __name__ == "__main__":
    _monitor_inbox()









