# Deep Learning - MCH2
Fachdozent: Martin Melchior     
Student: Manuel Schwarz   
HS23

Dieses Notebook bearbeitet die Mini-Challenge 2 des Moduls Deep Learning (del).   
Die Performance der Modelle wurde mit **wandb.ai** aufgezeichnet und kann [hier](https://wandb.ai/manuel-schwarz/del-mc2/workspace?workspace=user-manuel-schwarz) eingesehen werden.  

<div class="alert alert-block alert-info">
<b>Aufgabenstellung:</b> Eine blaue Box beschreibt die Aufgabe aus der Aufgabenstellung 'SGDS_DEL_MC1.pdf'
</div>

<div class="alert alert-block alert-success">
<b>Antwort:</b> Eine grüne Box beschreibt die Bearbeitung / Reflexion der Aufgabenstellung
</div>

In [None]:
import torch

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device: ', device)

In [None]:
import os
import copy
import time
import torch
import wandb
import random
import spacy  # conda install -c conda-forge spacy + python -m spacy download en_core_web_sm
import torchvision
import torchtext
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim
from PIL import Image
from tqdm import tqdm 
from datetime import datetime
from spacy.symbols import ORTH
from collections import Counter
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
from sklearn.model_selection import KFold
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchvision import datasets, models, transforms

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device: ', device)

# sound
import time
import winsound
import datetime

### Aufbau Modellierung und Daten

<div class="alert alert-block alert-info">

Überlege Dir, welche Modell-Architektur Sinn machen könnte. Mindestens zwei Modell-Varianten sollen aufgebaut werden, die miteinander verglichen werden sollen.

</div>

<div class="alert alert-block alert-success">
Für die del-MC2 Challenge wird das Modell aus dem Paper von Vinyals et al., `Show and Tell: A Neural Image Caption Generator`, nachgebaut. Das Paper stellt ein Modell vor, das zu einem Bild eine Bildbeschreibung erzeugt. Als Datengrundlage wird das Datenset `Flickr 8k` verwendet.
</div>

### Hilfsfunktionen

In [None]:
def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU, not only the current one;
    # it is silently ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed()

def play_sound(typ=0):
    """Play a notification sound asynchronously via ``winsound``.

    Args:
        typ: 0 plays ``beep.wav``, 1 plays ``beep2.wav``; any other value
            plays nothing.
    """
    sound_files = (
        (0, '../01_Dokumentation/win_sounds/beep.wav'),
        (1, '../01_Dokumentation/win_sounds/beep2.wav'),
    )
    for code, wav_path in sound_files:
        if typ == code:
            winsound.PlaySound(wav_path, winsound.SND_ASYNC)


### Daten Flickr 8k lesen

In [None]:
images_folder = './data/Images'
captions_file = './data/captions.txt'

In [None]:
# Read each line of the captions file as a single raw string column.
# Presumably the tab separator is chosen because it does not occur in the
# file, so commas inside the caption text cannot split a row - TODO confirm.
pd_captions = pd.read_csv('./data/captions.txt', sep='\t', header=None)
pd_captions.columns = ['full_caption']
# Split at the first comma only: the part before it becomes `image_name`,
# the remainder (which may itself contain commas) becomes `caption`.
pd_captions[['image_name', 'caption']] = pd_captions['full_caption'].str.split(',', n=1, expand=True)
# NOTE(review): the CSV is written before `full_caption` is dropped, so the
# exported file still contains the raw helper column - confirm this is intended.
pd_captions.to_csv('./data/pd_captions.csv', index=False)
pd_captions.drop('full_caption', axis=1, inplace=True)
pd_captions.head(10)

Im File `captions.txt` sind der Bildname und die Bildbeschreibung (Caption) hinterlegt. Pro Bild stehen fünf Captions zur Verfügung.

In [None]:
image_id = 5

# Collect the five consecutive caption rows starting at `image_id`.
example_captions = [pd_captions.caption[image_id + offset] for offset in range(5)]
(example_caption1, example_caption2, example_caption3,
 example_caption4, example_caption5) = example_captions

example_image_path = f'{images_folder}/{pd_captions.image_name[image_id]}'
image = Image.open(example_image_path)

# Show the image with all five captions stacked in the title.
plt.imshow(image)
plt.title(' \n '.join(f'{text}' for text in example_captions))
plt.axis('off')
plt.show()

### Untersuchen der Bilddaten

In [None]:
images_folder_path = './data/Images/'
image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')


def _read_image_size(path):
    """Open the image at *path* and return its PIL ``size`` tuple."""
    with Image.open(path) as img:
        return img.size


# One size tuple per image file in the folder; other files are ignored.
resolutions = [
    _read_image_size(os.path.join(images_folder_path, image_filename))
    for image_filename in os.listdir(images_folder_path)
    if image_filename.lower().endswith(image_suffixes)
]

In [None]:
# One "<width>x<height>" label per entry of `resolutions`.
dimension_labels = [f"{w}x{h}" for w, h in resolutions]
resolution_counts = Counter(dimension_labels)
# Keep only the 50 most frequent resolutions, most frequent first.
sorted_resolution_counts = dict(resolution_counts.most_common(50))

plt.figure(figsize=(10, 3))
plt.bar(sorted_resolution_counts.keys(), sorted_resolution_counts.values(), color='skyblue')

plt.title('Verteilung der Bildauflösungen von Flickr 8k (Top 50)')
# The labels are width x height ("b x h"); the axis label previously read
# "b x w", which did not match the plotted values.
plt.xlabel('Dimension (b x h)')
plt.ylabel('Anzahl Bilder')
plt.xticks(rotation=90)
plt.show()

In [None]:
widths, heights = zip(*resolutions)


def _box_style():
    """Return fresh styling kwargs: grey box, black lines, red median."""
    return dict(
        patch_artist=True,
        boxprops=dict(facecolor='grey', color='black'),
        whiskerprops=dict(color='black'),
        capprops=dict(color='black'),
        medianprops=dict(color='red'),
    )


# (values, title, x tick label) for the left and the right panel.
panels = (
    (widths, 'Verteilung der Bildbreiten', 'b'),
    (heights, 'Verteilung der Bildhöhen', 'h'),
)

plt.figure(figsize=(10, 3))
for panel_no, (values, panel_title, tick_label) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_no)
    # A new kwargs dict per call, exactly as in two separate literal calls.
    plt.boxplot(values, **_box_style())
    plt.title(panel_title)
    plt.ylabel('Pixel')
    plt.xticks([1], [tick_label])

plt.tight_layout()
plt.show()

### Untersuchen der Bildbeschreibungen

Die Textdaten werden im Folgenden in Tokens konvertiert:
[doku torchtext](http://man.hubwiz.com/docset/torchtext.docset/Contents/Resources/Documents/index.html)


In [None]:
# !pip show torchtext
# Version: 0.6.0

In [None]:
spacy_en = spacy.load('en_core_web_sm')

# Register the marker strings as tokenizer special cases, each described by
# a single ORTH entry equal to the marker itself.
special_cases = [(marker, [{ORTH: marker}]) for marker in ("<start>", "<end>", "<pad>")]
for marker, pattern in special_cases:
    spacy_en.tokenizer.add_special_case(marker, pattern)

def tokenize_en(caption, lower_text=False):
    """Split a caption into token strings with the spaCy tokenizer.

    Args:
        caption: Text to tokenize.
        lower_text: If true, every token text is lower-cased.

    Returns:
        List of token strings in their original order.
    """
    token_texts = (tok.text for tok in spacy_en.tokenizer(caption))
    if lower_text:
        return [text.lower() for text in token_texts]
    return list(token_texts)

print(f'Test tokenize: {example_caption1}')
tokens = tokenize_en(example_caption1)
tokens

In [None]:
# Frequency table of all tokens that occur in the captions.
# `value_counts()` labels its result differently across pandas versions
# (the count column is 'caption' before 2.0 and 'count' from 2.0 on), so the
# columns are named explicitly: later cells rely on `count_token.token` and
# `count_token.caption`.
token_series = pd_captions['caption'].apply(tokenize_en).explode()
count_token = (
    token_series.value_counts()
    .rename_axis('token')
    .reset_index(name='caption')
)
# count_token

In [None]:
num_token = 50
# First `num_token` rows of the frequency table.
top_tokens = count_token.head(num_token)

plt.figure(figsize = (10, 4))
plt.bar(top_tokens.token, top_tokens.caption, color='skyblue')
plt.title(f'Total Tokens: {len(count_token)} Tokens, dargestellt {num_token} Tokens')
plt.xlabel('Token')
plt.ylabel('Vorkommen')
plt.xticks(rotation=90)
plt.show()

In [None]:
def tokenize_en_len(caption, lower_text=False):
    """Return the number of spaCy tokens in a caption.

    Args:
        caption: Text to tokenize.
        lower_text: If true, token texts are lower-cased before counting.

    Returns:
        Number of tokens as an int.
    """
    token_texts = [tok.text for tok in spacy_en.tokenizer(caption)]
    if lower_text:
        token_texts = [text.lower() for text in token_texts]
    return len(token_texts)

token_series = pd_captions['caption'].apply(tokenize_en_len)
# token_series

In [None]:
# Histogram of the number of tokens per caption with mean and spread markers.
plt.figure(figsize = (10, 3))
token_mean = token_series.mean()
# NOTE(review): this is HALF the standard deviation, yet the legend entry
# below prints it as "std" - confirm whether the label or the band is intended.
token_std = token_series.std() / 2
plt.hist(token_series, color='skyblue', bins=40)
plt.axvline(token_mean, color='red', alpha=0.6, label=f'mean {token_mean:0.2f}')
# Shaded band from mean - token_std to mean + token_std.
plt.axvspan(token_mean-token_std, token_mean+token_std, color='grey', alpha=0.2, label=f'std {token_std:0.2f}')
plt.suptitle('Verteilung der Anzahl Tokens je Caption')
plt.title(f'min bei: {token_series.min()}, max bei: {token_series.max()}', fontsize=8)
plt.xlabel('Anzahl Tokens')
plt.ylabel('Vorkommen')
plt.legend()
plt.show()

### Bearbeitung der Captions für Modelling

- Im Paper werden Tokens, die weniger als fünfmal vorkommen, entfernt.
- Im Paper werden `start`- und `end`-Tokens eingeführt. Sie sollen dem Modell helfen zu erkennen, wann die Generierung einer Beschreibung beginnt und wann sie beendet werden soll.

In [None]:
min_num_token = 5
# Keep tokens whose count reaches the threshold (inclusive comparison).
count_token_filtered = count_token[count_token.caption >= min_num_token]
print(f'Total Tokens {len(count_token)}')
# The message previously said "mehr als", which contradicts the `>=` filter.
print(f'Anzahl Tokens die mindestens {min_num_token} mal vorkommen: {len(count_token_filtered)}')

In [None]:
START_TOKEN = "<start>"
END_TOKEN = "<end>"

# Wrap every caption in the start/end markers, on a copy so that the
# unmodified frame stays available for comparison.
pd_caption_mod = pd_captions.copy()
pd_caption_mod['caption'] = pd_caption_mod['caption'].apply(lambda x: f"{START_TOKEN} {x} {END_TOKEN}")
# Persist the MODIFIED frame; the unmodified `pd_captions` was written to
# this file before, so the export never contained the markers.
pd_caption_mod.to_csv('./data/pd_captions_mod.csv', index=False)

# Print the first three captions before and after the modification.
for row in range(3):
    print(pd_captions['caption'][row])
print()
for row in range(3):
    print(pd_caption_mod['caption'][row])


### Aufteilung in Trainings- und Testdaten

In [None]:
train_fraction = 0.8

print(f'Anzahl Captions: {len(pd_caption_mod)}')
unique_images = list(pd_caption_mod.image_name.unique())
print(f'Anzahl Bilder: {len(unique_images)}')

# The split is drawn on image names, so every caption row of an image ends
# up in the same subset.
n_train_images = int(len(unique_images) * train_fraction)
train_images = random.sample(unique_images, k=n_train_images)
test_images = list(set(unique_images) - set(train_images))

print(f'Länge Trainingsset: {len(train_images)}')
print(f'Länge Testset: {len(test_images)}')

in_train = pd_caption_mod.image_name.isin(train_images)
in_test = pd_caption_mod.image_name.isin(test_images)
pd_train_set = pd_caption_mod[in_train]
pd_test_set = pd_caption_mod[in_test]

# Persist both subsets; later cells read them back from disk.
for frame, target in (
    (pd_train_set, './data/train_captions.csv'),
    (pd_test_set, './data/test_captions.csv'),
):
    frame.to_csv(target, index=False)

print(f'Länge Trainingsset: {len(pd_train_set)}')
print(f'Länge Testset: {len(pd_test_set)}')


In [None]:
train_set = pd.read_csv('./data/train_captions.csv')
test_set = pd.read_csv('./data/test_captions.csv')

train_set.head(10)

In [None]:
image_id = 5

# Collect the five consecutive caption rows of the training frame starting
# at `image_id`.
example_captions = [train_set.caption[image_id + offset] for offset in range(5)]
(example_caption1, example_caption2, example_caption3,
 example_caption4, example_caption5) = example_captions

example_image_path = f'{images_folder}/{train_set.image_name[image_id]}'
image = Image.open(example_image_path)

# Show the image with all five captions stacked in the title.
plt.imshow(image)
plt.title(' \n '.join(f'{text}' for text in example_captions))
plt.axis('off')
plt.show()

### Bearbeitung der Captions für Modelling (Trainingsset)
- Captions in einer Matrix ablegen; alle Captions werden per Padding auf die gleiche Länge gebracht

In [None]:
token_series = train_set['caption'].apply(tokenize_en_len)
print(f'maximale caption länge: {token_series.max()} Tokens')

In [None]:
def build_vocab(tokenized_captions, min_freq):
    """Build a token -> index vocabulary from tokenized captions.

    Indices 0, 1 and 2 are reserved for ``<pad>``, ``<start>`` and ``<end>``.
    Every other token that occurs at least ``min_freq`` times receives the
    next free index, in order of first appearance.

    Args:
        tokenized_captions: Iterable of token lists, one list per caption.
        min_freq: Minimum number of occurrences for a token to be kept.

    Returns:
        dict mapping each token string to a contiguous integer index
        (``0 .. len(vocab) - 1``).
    """
    # Count how often each token occurs across all captions.
    token_counts = Counter(token for caption in tokenized_captions for token in caption)

    vocab = {"<pad>": 0, "<start>": 1, "<end>": 2}
    for token, count in token_counts.items():
        # The reserved markers also occur inside the captions. Without the
        # membership check they would overwrite their fixed indices, leaving
        # unused gaps and a maximum index larger than len(vocab) - 1.
        if count >= min_freq and token not in vocab:
            vocab[token] = len(vocab)
    return vocab

def create_matrices_from_captions(train_set, min_freq=5):
    """Convert the captions of a frame into a padded index matrix.

    Captions are tokenized in lower case, a vocabulary is built from them
    and every caption is encoded as a row of token indices.

    Args:
        train_set: Object whose ``'caption'`` entry yields the caption
            strings (a pandas DataFrame in this notebook).
        min_freq: Minimum token frequency passed on to ``build_vocab``.
            Defaults to 5, the value that was hard-coded before.

    Returns:
        Tuple ``(padded_captions, vocab)``: the padded index tensor with one
        row per caption, and the token -> index dict.
    """
    tokenized_captions = [
        tokenize_en(caption, lower_text=True) for caption in train_set['caption']
    ]
    vocab = build_vocab(tokenized_captions, min_freq=min_freq)
    pad_id = vocab["<pad>"]

    # Tokens that did not make it into the vocabulary are encoded with the
    # <pad> index.
    caption_tensors = [
        torch.tensor([vocab.get(token, pad_id) for token in caption])
        for caption in tokenized_captions
    ]

    # pad_sequence pads every row to the longest caption on its own, so no
    # separate maximum-length computation is needed.
    padded_captions = pad_sequence(caption_tensors, batch_first=True, padding_value=pad_id)

    return padded_captions, vocab

def indices_to_words(tensor_indices, vocab, rm_pad=True):
    """Decode a sequence of index tensors back into token strings.

    Args:
        tensor_indices: Iterable of scalar tensors (e.g. a 1-D tensor).
        vocab: Token -> index dict.
        rm_pad: If true, ``'<pad>'`` tokens are left out of the result.

    Returns:
        List of token strings; indices missing from ``vocab`` become
        ``'<unk>'``.
    """
    lookup = dict(zip(vocab.values(), vocab.keys()))
    decoded = []
    for entry in tensor_indices:
        token = lookup.get(entry.item(), '<unk>')
        if rm_pad and token == '<pad>':
            continue
        decoded.append(token)
    return decoded


In [None]:
padded_captions, vocab = create_matrices_from_captions(train_set)
print(f'Matrix dim: {padded_captions.shape}')
print(f'Länge Wörterbuch: {len(vocab)}')
padded_captions[:2]


### Erstellen des Dataloaders

https://nlp.stanford.edu/projects/glove/

In [None]:
# from torchtext.vocab import GloVe
# glove = GloVe(name='6B', dim=300)
# glove_embeddings = glove.vectors
# print(glove_embeddings.shape)


In [None]:
padded_captions, vocab = create_matrices_from_captions(train_set)

In [None]:
def load_glove_embeddings(path='./data/glove.6B.300d.txt'):
    """Load word vectors from a whitespace-separated text file.

    Every non-empty line is expected to hold a word followed by its vector
    components.

    Args:
        path: Location of the embeddings text file.

    Returns:
        dict mapping each word to a 1-D ``torch.float`` tensor.
    """
    embeddings = {}
    with open(path, 'r', encoding='utf-8') as f:
        # Iterate lazily: readlines() would hold the complete file in memory
        # in addition to the tensors being built.
        for line in f:
            values = line.split()
            if not values:
                # A blank line has no word to index; skip it instead of
                # failing on values[0].
                continue
            word, *components = values
            embeddings[word] = torch.tensor(
                [float(val) for val in components], dtype=torch.float
            )
    return embeddings

glove_embeddings = load_glove_embeddings()


In [None]:
def create_embedding_matrix(vocab, glove_embeddings, embedding_dim = 300):
    """Assemble an embedding matrix for the vocabulary.

    Args:
        vocab: Token -> row index dict.
        glove_embeddings: Word -> vector dict with pretrained embeddings.
        embedding_dim: Length of every embedding vector.

    Returns:
        Tensor with ``len(vocab) + 2`` rows and ``embedding_dim`` columns.
        Rows of known words hold the pretrained vector, rows of unknown
        words a random normal vector, rows not referenced by ``vocab`` stay
        zero.
    """
    n_rows = len(vocab) + 2
    matrix = torch.zeros((n_rows, embedding_dim))
    for token, row in vocab.items():
        # Random vectors are drawn only for words without a pretrained entry.
        matrix[row] = (
            glove_embeddings[token]
            if token in glove_embeddings
            else torch.randn(embedding_dim)
        )
    return matrix

embedding_matrix = create_embedding_matrix(vocab, glove_embeddings)
embedding_matrix.shape

In [None]:
class FlickrDataset(Dataset):
    """Image/caption samples backed by a captions CSV file.

    Every CSV row is one sample: column 0 is read as the image file name
    (relative to ``root_dir``) and column 1 as the caption text.
    """

    def __init__(self, csv_file_name, root_dir, vocab, embedding_matrix, transform=None):
        """Read the CSV and store the lookup structures.

        Args:
            csv_file_name: Path of the captions CSV.
            root_dir: Directory that contains the image files.
            vocab: Token -> index dict; must contain ``'<pad>'``.
            embedding_matrix: Tensor indexed by token index.
            transform: Optional callable applied to the loaded PIL image.
        """
        self.captions_frame = pd.read_csv(csv_file_name)
        self.root_dir = root_dir
        self.vocab = vocab
        self.embedding_matrix = embedding_matrix
        self.transform = transform

    def __len__(self):
        """Return the number of caption rows."""
        return len(self.captions_frame)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(image, caption, caption_indices_tensor,
            caption_embeddings)``: the (optionally transformed) image, the
            raw caption string, its token indices as a long tensor and the
            matching rows of the embedding matrix.
        """
        img_name = os.path.join(self.root_dir, self.captions_frame.iloc[idx, 0])
        # Close the file handle right away and force three channels so that
        # greyscale or palette files work with the 3-channel normalisation.
        with Image.open(img_name) as img:
            image = img.convert('RGB')

        caption = self.captions_frame.iloc[idx, 1]
        tokenized_caption = tokenize_en(caption, lower_text=True)

        # Tokens missing from the vocabulary are encoded with the <pad> index.
        pad_id = self.vocab['<pad>']
        caption_indices = [self.vocab.get(token, pad_id) for token in tokenized_caption]
        caption_indices_tensor = torch.tensor(caption_indices, dtype=torch.long)

        # One indexing operation gathers all embedding rows at once.
        caption_embeddings = self.embedding_matrix[caption_indices_tensor]

        if self.transform:
            image = self.transform(image)

        return image, caption, caption_indices_tensor, caption_embeddings

    def collate_fn(self, batch):
        """Merge samples into a batch, padding captions to equal length.

        Args:
            batch: List of tuples as returned by ``__getitem__``.

        Returns:
            Tuple ``(images, caption, caption_indices_padded,
            caption_embeddings_padded)``; ``caption`` stays a tuple of
            strings.
        """
        images, caption, caption_indices, caption_embeddings = zip(*batch)

        caption_indices_padded = pad_sequence(
            caption_indices, batch_first=True, padding_value=self.vocab['<pad>']
        )
        # Embeddings are float vectors: pad them with zeros rather than with
        # a token index that only happens to be 0.
        caption_embeddings_padded = pad_sequence(
            caption_embeddings, batch_first=True, padding_value=0.0
        )

        images = torch.stack(images)  # stack the images into one tensor

        return images, caption, caption_indices_padded, caption_embeddings_padded
    


In [None]:
# NOTE: from here on `train_set` is a FlickrDataset; the name previously
# referred to the DataFrame read from train_captions.csv.
train_set = FlickrDataset(
    csv_file_name='./data/train_captions.csv',
    root_dir='./data/Images',
    vocab=vocab,
    embedding_matrix=embedding_matrix,
    transform=None
)

# Smoke test of the dataset: fetch the first sample (no transform, so the
# image is still a PIL image that imshow can display).
image, caption, caption_indices, caption_embeddings = train_set[0]

plt.figure(figsize = (10, 5))
plt.imshow(image)
plt.suptitle(f'{caption}', fontsize=10)
plt.title(f'Länge caption indices: {len(caption_indices)}, Länge caption_embeddings: {len(caption_embeddings)}', fontsize=8)
plt.axis('off') 
plt.show()

### Erstellen und Testen des Dataloaders

In [None]:
# Channel statistics used for normalisation; these are the values that
# torchvision documents for its ImageNet-pretrained models.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# Inverse of `normalize`, used to display normalised tensors as images.
denormalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
    std=[1 / s for s in IMAGENET_STD]
)

# Training pipeline: the random crop acts as light augmentation.
transformations = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    normalize
])

# Test pipeline: a centre crop keeps evaluation deterministic. The test set
# used the random crop before, so its inputs changed from run to run.
test_transformations = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

train_set = FlickrDataset(
    csv_file_name='./data/train_captions.csv',
    root_dir='./data/Images',
    vocab=vocab,
    embedding_matrix=embedding_matrix,
    transform=transformations
)

test_set = FlickrDataset(
    csv_file_name='./data/test_captions.csv',
    root_dir='./data/Images',
    vocab=vocab,
    embedding_matrix=embedding_matrix,
    transform=test_transformations
)

train_dataloader = DataLoader(train_set, batch_size=4, shuffle=True, collate_fn=train_set.collate_fn)
test_dataloader = DataLoader(test_set, batch_size=4, shuffle=False, collate_fn=test_set.collate_fn)

In [None]:
# Check the dataloader on its first batch: after padding, the index rows and
# the embedding rows of all four samples should have the same length.
for i_batch, (image, caption, caption_indices, caption_embeddings) in enumerate(train_dataloader):
    for sample_no in range(4):
        n_indices = len(caption_indices[sample_no])
        n_embeddings = len(caption_embeddings[sample_no])
        print(f'Länge caption_indices {sample_no + 1}: {n_indices}, Länge caption_embeddings {sample_no + 1}: {n_embeddings}')
    print()
    print('Example Image 1/4')

    # Undo the normalisation so the first image of the batch can be shown.
    img_pil = transforms.ToPILImage()(denormalize(image[0]))
    decoded_words = indices_to_words(caption_indices[0], vocab)

    plt.figure(figsize = (10, 4))
    plt.imshow(img_pil)
    plt.suptitle(caption[0])
    plt.title(f'wörter from caption_indices: {decoded_words}', fontsize=7)
    plt.axis('off')
    plt.show()
    break

### Modellierung Modell nach Paper `Show and Tell`

In [None]:
class ImageCaptioningModel(nn.Module):
    """CNN encoder + LSTM decoder for image captioning.

    A pretrained ResNet-50 (without its final layer) turns an image into a
    feature vector. A linear layer projects that vector to the embedding
    size, and it is fed to the LSTM as the first sequence element, followed
    by the caption token embeddings.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout_prob=0.5, glove_em=None):
        """Build the encoder, the embedding layer and the decoder.

        Args:
            vocab_size: Number of output classes of the final linear layer
                (and rows of the embedding layer when it is trained from
                scratch).
            embedding_dim: Size of token embeddings and projected image
                features.
            hidden_dim: Hidden size of the LSTM.
            num_layers: Number of stacked LSTM layers.
            dropout_prob: Dropout probability for features and embeddings.
            glove_em: Optional pretrained embedding matrix; when given it
                is used frozen instead of a trainable embedding layer.

        Raises:
            ValueError: If ``glove_em`` does not have ``embedding_dim``
                columns.
        """
        super().__init__()

        # Pretrained ResNet-50 without its last (classification) layer.
        resnet = models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        # The encoder is only used as a fixed feature extractor (forward runs
        # it under no_grad), so mark its weights as frozen explicitly.
        for param in self.resnet.parameters():
            param.requires_grad_(False)
        self.fc = nn.Linear(resnet.fc.in_features, embedding_dim)

        # Embedding layer for the caption tokens.
        if glove_em is not None:
            if glove_em.size(1) != embedding_dim:
                # Fail early: the LSTM input and the image features are
                # concatenated along the sequence axis and need equal width.
                raise ValueError(
                    f'glove_em has {glove_em.size(1)} columns, expected {embedding_dim}'
                )
            self.embedding = nn.Embedding.from_pretrained(glove_em, freeze=True)
            print('using glove')
        else:
            self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # LSTM that generates the caption.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.dropout = nn.Dropout(dropout_prob)
        # Projects every LSTM output to scores over the vocabulary.
        self.linear = nn.Linear(hidden_dim, vocab_size)
        self.batch_norm = nn.BatchNorm1d(embedding_dim)

    def train(self, mode=True):
        """Switch train/eval mode but keep the frozen encoder in eval mode.

        In train mode the BatchNorm layers inside the ResNet would keep
        updating their running statistics even under ``no_grad``, slowly
        changing the pretrained encoder.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, images, captions):
        """Compute vocabulary scores for every sequence position.

        Args:
            images: Batch of image tensors accepted by the ResNet.
            captions: Batch of token index sequences (2-D, batch first).

        Returns:
            Tensor of scores with one more sequence position than
            ``captions``, because the image feature is prepended.
        """
        # CNN part: no gradients for the frozen encoder.
        with torch.no_grad():
            features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.fc(features)

        features = self.batch_norm(features)
        features = self.dropout(features)

        # Embedding and LSTM part: the image feature is the first element of
        # the input sequence.
        embeddings = self.embedding(captions)
        embeddings = self.dropout(embeddings)
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)

        return outputs


### Training Modell

In [None]:
def train_modell(config, model, dataloader, optimizer, criterion, epochs, device, test_batch=False):
    """Train a captioning model and optionally log the run to wandb.

    Args:
        config: Dict with run settings. Read keys: ``'set_seed'``,
            ``'write_wandb'`` and, when wandb logging is on, ``'name'``,
            ``'epochs'``, ``'start_time'``, ``'group'``, ``'tags'`` and
            ``'is_test_batch'``.
        model: Model called as ``model(images, token_indices)``.
        dataloader: Yields ``(images, captions, caption_indices,
            embeddings)`` batches.
        optimizer: Optimizer updating the model parameters.
        criterion: Loss taking flattened scores and flattened targets.
        epochs: Number of passes over ``dataloader``. Note that the wandb
            run name is built from ``config['epochs']``, not from this
            argument.
        device: Device the model and the batches are moved to.
        test_batch: If true, only the first two batches of every epoch are
            processed.

    Returns:
        Tuple ``(epoch_losses, batch_losses)``: mean loss per epoch and the
        loss of every processed batch.
    """
    set_seed(config['set_seed'])
    model.to(device)
    model.train()

    epoch_losses = []
    batch_losses = []
    use_wandb = config['write_wandb']

    # Initialize wandb
    if use_wandb:
        model_name = f"{config['name']}-{config['epochs']}-epochs-{config['start_time']}"
        wandb.init(
            project="del-mc2",
            entity='manuel-schwarz',
            group=config['group'],
            name= model_name,
            tags=str(config['tags']) + (' is_test_batch' if config['is_test_batch'] else ''),
            config=config
        )
        wandb.watch(model)

    try:
        for epoch in range(epochs):
            ep_loss = []
            loop = tqdm(enumerate(dataloader), total=len(dataloader), leave=False)
            for i, (images, _captions, caption_indices, _embeddings) in loop:
                if test_batch and i > 1:
                    break

                images = images.to(device)
                token_ids = caption_indices.to(device)

                # The last column is dropped from the input. The
                # ImageCaptioningModel of this notebook prepends the image
                # feature, so the output has as many positions as the full
                # index sequence used as target.
                outputs = model(images, token_ids[:, :-1])
                targets = token_ids.contiguous().view(-1)
                output_shaped = outputs.view(-1, outputs.size(-1))

                loss = criterion(output_shaped, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                batch_losses.append(loss_value)
                ep_loss.append(loss_value)

            mean_loss = np.mean(ep_loss)
            epoch_losses.append(mean_loss)

            if use_wandb:
                wandb.log({
                    "train loss epoch": mean_loss
                    })
    finally:
        # Close the run even when training is interrupted or fails, so no
        # half-open wandb run is left behind.
        if use_wandb:
            wandb.finish()

    if use_wandb:
        time.sleep(5)  # wait for wandb.finish

    print('Modell finished!')
    play_sound(1)
    return epoch_losses, batch_losses

def plot_loss_model(epoch_losses, name='-'):
    """Plot the training loss of every epoch as a line chart.

    Args:
        epoch_losses: Sequence with one loss value per epoch.
        name: Model name shown in the title.
    """
    _, axis = plt.subplots(figsize=(10,5))
    axis.plot(epoch_losses, marker='o', color='skyblue', label='Training loss per epoch')
    axis.set_title(f'Training Loss per Epoch (model: {name})')
    axis.set_xlabel('Epoche')
    axis.set_ylabel('Loss')
    # The legend is intentionally not drawn.
    axis.grid(True)
    plt.show()

def save_model(model, config):
    """Save the model's state dict as a ``.tar`` file under ``./models``.

    The file name is built from ``config['start_time']``,
    ``config['name']`` and ``config['epochs']``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        config: Dict providing the three keys named above.
    """
    target_dir = './models'
    # torch.save does not create missing directories.
    os.makedirs(target_dir, exist_ok=True)
    file_name = f'{config["start_time"]}_{config["name"]}_epochs_{config["epochs"]}.tar'
    torch.save({
        'state_dict': model.state_dict(),
    }, f'{target_dir}/{file_name}')
    
def load_model(model, model_name, config):
    """Load a saved state dict from ``./models/<model_name>.tar`` into a model.

    Args:
        model: Module that receives the stored weights.
        model_name: File name without the ``.tar`` extension.
        config: Accepted but not read by this function.

    Returns:
        The same ``model`` instance with the loaded weights.
    """
    checkpoint_path = f'./models/{model_name}.tar'
    # Tensors are mapped onto the notebook-level `device`.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    return model


In [None]:
# Re-import the class: the earlier `import datetime` rebinds the name
# `datetime` to the module, which has no `now()`.
from datetime import datetime

# Hyperparameters example
config = {
    "name": "CNN_LSTM", 
    "epochs": 2,   
    "train_batch_size": 4, 
    "test_batch_size": 2,
    "dataset": "flickr8k",
    "lr": 0.1, 
    "optimizer": 'SGD',
    "loss_func": 'CrossEntropyLoss',
    "image_size": 256,
    "is_test_batch": False,
    "start_time": datetime.now().strftime("%d.%m.%Y_%H%M"),
    "num_workers": 0,
    "dropout": 0.5,
    "set_seed": 42,
    'vocab_size': len(vocab)+1,
    'embedding_dim': 300,
    'hidden_dim': 512,
    'num_layers': 1,
    'write_wandb':True,
    'group': 'cpu test',
    'tags': 'tests',
    'use_glove_emb': True,
    'save_model': True
}

train_dataloader = DataLoader(
    train_set, 
    batch_size=config['train_batch_size'], 
    shuffle=True, 
    collate_fn=train_set.collate_fn
)
test_dataloader = DataLoader(
    test_set, 
    batch_size=config['test_batch_size'], 
    shuffle=False, 
    collate_fn=test_set.collate_fn
)

model1 = ImageCaptioningModel(
    vocab_size = config['vocab_size'], 
    embedding_dim = config['embedding_dim'], 
    hidden_dim = config['hidden_dim'], 
    num_layers = config['num_layers'],
    dropout_prob= config['dropout'],
    glove_em = embedding_matrix if config['use_glove_emb'] else None
)

# optimizer = optim.Adam(model1.parameters(), lr=0.001)
optimizer = torch.optim.SGD(model1.parameters(), lr=config['lr'])
criterion = nn.CrossEntropyLoss()

# The epoch count comes from the config. It was hard-coded to 5 here while
# the wandb run name and the checkpoint file name were built from
# config['epochs'], so the recorded value did not match the training run.
epoch_losses, batch_losses = train_modell(
    config,
    model1,
    train_dataloader,
    optimizer, 
    criterion, 
    epochs=config['epochs'], 
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    test_batch=config['is_test_batch']
)

if config['save_model']:
    save_model(model1, config)
    print('Modell saved!')

plot_loss_model(epoch_losses, 'model1')

### Modell Vorhersagen

In [None]:
model_name = '30.12.2023_1630_CNN_LSTM_epochs_2'
loaded_model = load_model(model1, model_name, config)

