# Deep Learning - MCH2
Fachdozent: Martin Melchior     
Student: Manuel Schwarz   
HS23

Dieses Notebook bearbeitet die Mini-Challenge 2 des Moduls Deep Learning (del).   
Die Performance der Modelle wurde mit **wandb.ai** aufgezeichnet und kann [hier](https://wandb.ai/manuel-schwarz/del-mc2/workspace?workspace=user-manuel-schwarz) eingesehen werden.  

<div class="alert alert-block alert-info">
<b>Aufgabenstellung:</b> Eine blaue Box beschreibt die Aufgabe aus der Aufgabenstellung 'SGDS_DEL_MC1.pdf'.
</div>

<div class="alert alert-block alert-success">
<b>Antwort:</b> Eine grüne Box beschreibt die Bearbeitung / Reflexion der Aufgabenstellung.
</div>

In [None]:
import os
import copy
import time
import torch
import wandb
import random
import torchvision
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim
from PIL import Image
from tqdm import tqdm 
from datetime import datetime
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
from sklearn.model_selection import KFold
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchvision import datasets, models, transforms

# Device configuration: use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('device: ', device)

# sound
import time  # repeated: `time` is already imported earlier in this cell
import winsound  # NOTE(review): winsound is a Windows-only stdlib module; this import fails on other platforms
import datetime  # NOTE(review): rebinds `datetime` from the class (`from datetime import datetime` above) to the module

### Aufbau Modellierung und Daten

<div class="alert alert-block alert-info">

Überlege Dir, welche Modell-Architektur Sinn machen könnte. Mindestens zwei Modell-Varianten sollen aufgebaut werden, die miteinander verglichen werden sollen.

</div>

<div class="alert alert-block alert-success">
Für die del-MC2-Challenge wird das Modell aus dem Paper von Vinyals et al., `Show and Tell: A Neural Image Caption Generator`, nachgebaut. Das Paper entwickelt ein Modell, welches für Bilder eine Bildbeschreibung erstellt. Als Daten wird das `Flickr 8k`-Datenset verwendet.
</div>

### Daten Flickr 8k lesen

In [None]:
# Locations of the Flickr 8k caption file and image directory, relative to the notebook.
captions_file = './data/captions.txt'
images_folder = './data/Images'

In [None]:
# Load the caption file with a tab separator so that every line ends up in a single column
# (presumably the file contains no tabs — the one-name column assignment below relies on it).
# NOTE(review): header=None means a header line in captions.txt, if present, becomes a data row — confirm.
pd_captions = pd.read_csv('./data/captions.txt', sep='\t', header=None)
pd_captions.columns = ['full_caption']
# Split each raw line at the first comma only (n=1): left part is the image name, the rest is the caption.
pd_captions[['image_name', 'caption']] = pd_captions['full_caption'].str.split(',', n=1, expand=True)
# The CSV is written before the raw column is dropped, so it contains all three columns.
pd_captions.to_csv('./data/pd_captions.csv', index=False)
pd_captions.drop('full_caption', axis=1, inplace=True)
# Notebook display of the first ten rows.
pd_captions.head(10)

Im File `captions.txt` sind der Bildname und die Bildbeschreibung (Caption) hinterlegt. Pro Bild stehen fünf Captions zur Verfügung.

In [None]:
# Row index used to pick the example image.
image_id = 5

# Collect the captions by image name instead of by row offset (image_id+0..4): the
# offsets only hit the right rows when image_id happens to be the first row of an
# image's caption group, otherwise the title describes a different image.
example_image_name = pd_captions.image_name[image_id]
example_image_path = f'{images_folder}/{example_image_name}'
example_captions = pd_captions.loc[
    pd_captions.image_name == example_image_name, 'caption'
].tolist()
image = Image.open(example_image_path)

# Show the image with all of its captions as a multi-line title.
plt.imshow(image)
plt.title(' \n '.join(str(caption) for caption in example_captions))
plt.axis('off')
plt.show()

### Aufteilung in Trainings- und Testdaten

In [None]:
print(f'Anzahl Captions: {len(pd_captions)}')
unique_images = list(pd_captions.image_name.unique())
print(f'Anzahl Bilder: {len(unique_images)}')

# Split on image level (80 % train / 20 % test) so that all captions of one image end
# up in the same set. A dedicated, seeded generator makes the split reproducible:
# re-running this cell rewrites the CSV files with the same split instead of a new one.
split_rng = random.Random(42)
train_images = split_rng.sample(unique_images, k=int(len(unique_images) * 0.8))
train_image_names = set(train_images)
# Keep the original order for the test images (a set difference has arbitrary order).
test_images = [name for name in unique_images if name not in train_image_names]

# Number of images per set.
print(f'Länge Trainingsset: {len(train_images)}')
print(f'Länge Testset: {len(test_images)}')

# Select the caption rows belonging to each set and persist them.
pd_train_set = pd_captions[pd_captions.image_name.isin(train_images)]
pd_test_set = pd_captions[pd_captions.image_name.isin(test_images)]
pd_train_set.to_csv('./data/train_captions.csv', index=False)
pd_test_set.to_csv('./data/test_captions.csv', index=False)

# Number of caption rows per set.
print(f'Länge Trainingsset: {len(pd_train_set)}')
print(f'Länge Testset: {len(pd_test_set)}')


In [None]:
# Reload the persisted split from disk.
# `pd_captions` is intentionally NOT rebound here: the previous chained assignment left
# it pointing at the test set only, so re-running the split cell afterwards would have
# split (and overwritten the CSV files with) just that fraction of the data.
train_set = pd.read_csv('./data/train_captions.csv')
test_set = pd.read_csv('./data/test_captions.csv')

# Notebook display of the first ten training rows.
train_set.head(10)

In [None]:
# Row index used to pick the example image from the training set.
image_id = 5

# Collect the captions by image name instead of by row offset (image_id+0..4): the
# offsets only hit the right rows when image_id happens to be the first row of an
# image's caption group, otherwise the title describes a different image.
example_image_name = train_set.image_name[image_id]
example_image_path = f'{images_folder}/{example_image_name}'
example_captions = train_set.loc[
    train_set.image_name == example_image_name, 'caption'
].tolist()
image = Image.open(example_image_path)

# Show the image with all of its captions as a multi-line title.
plt.imshow(image)
plt.title(' \n '.join(str(caption) for caption in example_captions))
plt.axis('off')
plt.show()

### Erstellen des Dataloaders

In [None]:
# Caption file and image directory consumed by the dataset defined below.
captions_file = './data/captions.txt'
images_folder = './data/Images'

class Flickr8kDataset(Dataset):
    """Flickr 8k images together with all captions of each image.

    One item corresponds to one image; its captions are returned as a list of strings.

    Two caption file layouts are accepted, decided line by line:

    * ``<image_name>#<n><TAB><caption>`` (token file layout)
    * ``<image_name>,<caption>`` (CSV layout, optionally with an ``image,caption`` header)

    Args:
        images_folder: Directory that contains the image files.
        captions_file: Path to the caption file.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, images_folder, captions_file, transform=None):
        self.images_folder = images_folder
        self.transform = transform

        # Maps image file name -> list of captions, in file order.
        self.captions = {}
        with open(captions_file, 'r', encoding='utf-8') as file:
            for line in file:
                parsed = self._parse_caption_line(line)
                if parsed is None:
                    continue
                image_id, caption = parsed
                self.captions.setdefault(image_id, []).append(caption)

        self.image_ids = list(self.captions.keys())

    @staticmethod
    def _parse_caption_line(line):
        """Return ``(image_name, caption)`` for one caption line, or ``None`` to skip it."""
        line = line.strip()

        if '\t' in line:
            parts = line.split('\t')
            if len(parts) != 2:
                return None
            image_ref, caption = parts
            # Remove the "#<n>" caption counter; names without it are kept unchanged
            # (cutting two characters unconditionally would damage such names).
            image_id, marker, _ = image_ref.rpartition('#')
            return (image_id if marker else image_ref), caption

        # CSV layout: split at the first comma only, captions may contain commas.
        image_id, comma, caption = line.partition(',')
        if not comma or not image_id:
            return None
        if image_id == 'image' and caption == 'caption':
            # Header line of the CSV layout, not a real entry.
            return None
        return image_id, caption

    def __len__(self):
        """Return the number of distinct images."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, captions)`` for the image at position ``idx``.

        The image is opened from ``images_folder``, converted to RGB and passed through
        ``transform`` when one was given; ``captions`` is the list of its caption strings.
        """
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.images_folder, image_id)
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        captions = self.captions[image_id]
        return image, captions

# Image preprocessing: resize to 256x256, take a random 224x224 crop, convert to a tensor.
transformations = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.RandomCrop(224), transforms.ToTensor()]
)

# Dataset over all images listed in the caption file.
flickr8k_dataset = Flickr8kDataset(images_folder, captions_file, transform=transformations)

# Batches of 32 images in file order, loaded in the main process.
flickr8k_dataloader = DataLoader(flickr8k_dataset, batch_size=32, shuffle=False, num_workers=0)

# Smoke test: print the shape of the first image batch and the first entry of its captions.
for images, captions in flickr8k_dataloader:
    print(images.shape)
    print(captions[0])
    break