# Data Preprocessing

### <font color='blue'>Cerebro Imports</font>

In [None]:
from cerebro.etl.etl_spec import ETLSpec
from cerebro.experiment import Experiment
from cerebro.mop.sub_epoch_spec import SubEpochSpec

### <font color='blue'> Initialize Data Preprocessing </font>

In [None]:
class CocoETLSpec(ETLSpec):
    """ETL spec that converts COCO caption rows into (image, caption) tensors.

    The constructor builds the training vocabulary and the train/validation
    image transform pipelines once; ``row_preprocessor`` applies them per row.
    """

    def __init__(self):
        """Build the vocabulary and the image transform pipelines."""
        # Imports stay method-local, as everywhere else in this file.
        from torchvision import transforms
        from coco_proc.vocabulary import Vocabulary

        vocab_threshold = 5
        # Number of caption tokens (start/end markers excluded) every caption
        # is padded to in row_preprocessor.
        self.max_caption_len = 55
        self.miscellaneous_path = "/data/cerebro_data_storage/miscellaneous"
        # NOTE(review): not read anywhere in this class; presumably consumed
        # by the Cerebro ETL framework (one flag per dataset field) — confirm.
        self.is_feature_download = [False, True, False, False, False, False]
        annotations_file = self.miscellaneous_path + "/captions_train2017.json"
        self.train_vocab = Vocabulary(vocab_threshold, annotations_file=annotations_file, vocab_from_file=False)

        # Per-channel normalisation shared by both pipelines
        # (presumably the usual ImageNet mean/std — confirm).
        channel_mean = (0.485, 0.456, 0.406)
        channel_std = (0.229, 0.224, 0.225)

        # Training pipeline: random crop + random flip.
        self.train_img_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

        # Validation pipeline: deterministic centre crop, no flipping.
        self.valid_img_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

    def initialize_worker(self):
        """Best-effort download of the NLTK "punkt" tokenizer data.

        A failure (missing nltk, no network, ...) is reported as a warning
        instead of being raised, so worker start-up is never aborted here.
        """
        import logging

        try:
            import nltk
            nltk.download("punkt")
        except Exception as exc:
            # Only Exception is caught so KeyboardInterrupt/SystemExit still
            # propagate; the failure is logged rather than silently dropped.
            logging.getLogger(__name__).warning(
                "Could not download NLTK 'punkt' data: %s", exc)

    def row_preprocessor(self, row, mode, object_dir):
        """Convert one dataset row into an ``(image_tensor, caption_tensor)`` pair.

        Args:
            row: Row supporting ``row["file_name"]`` and ``row["captions"]``.
            mode: ``"train"`` selects the training image transform; any other
                value selects the validation transform.
            object_dir: Directory containing the image named by
                ``row["file_name"]``.

        Returns:
            A tuple ``(image_tensor, caption_tensor)``. ``caption_tensor`` is a
            1-D long tensor of ``max_caption_len + 2`` vocabulary ids: the
            start id, the token ids, the end id, then end-id padding.
        """
        import nltk
        import torch
        from PIL import Image

        vocab = self.train_vocab
        max_caption_len = self.max_caption_len
        img_transform = self.train_img_transform if mode == "train" else self.valid_img_transform

        # Read the input image; the context manager releases the file handle
        # as soon as the RGB copy has been made.
        input_image_path = object_dir + "/" + str(row["file_name"])
        with Image.open(input_image_path) as raw_image:
            image = raw_image.convert("RGB")
        image_tensor = img_transform(image)

        # Tokenise the caption. Over-long captions are truncated so that every
        # caption tensor has the same length and rows can be stacked in a batch.
        tokens = nltk.tokenize.word_tokenize(str(row["captions"]).lower())
        tokens = tokens[:max_caption_len]

        end_id = vocab(vocab.end_word)
        caption = [vocab(vocab.start_word)]
        caption.extend(vocab(token) for token in tokens)
        caption.append(end_id)

        # Pad with the end-word id (the multiplier is never negative after the
        # truncation above).
        caption.extend([end_id] * (max_caption_len - len(tokens)))

        # Build the integer tensor directly instead of going through float32.
        caption_tensor = torch.tensor(caption, dtype=torch.long)

        return image_tensor, caption_tensor

### <font color='blue'> Define the model training and validation functions </font>

In [None]:
class CocoTrainingSpec(SubEpochSpec):
    """Sub-epoch training/validation spec for the encoder-decoder captioning model.

    The constructor builds the training vocabulary; ``create_model_components``
    creates the encoder/decoder pair and its optimizer; ``train`` and ``test``
    run one pass over the dataloader they are given.
    """

    def __init__(self):
        """Build the training vocabulary from the train annotations file."""
        from coco_proc.vocabulary import Vocabulary

        vocab_threshold = 5
        self.miscellaneous_path = "/data/cerebro_data_storage/miscellaneous"
        annotations_file = self.miscellaneous_path + "/captions_train2017.json"
        self.train_vocab = Vocabulary(vocab_threshold, annotations_file=annotations_file, vocab_from_file=False)

    def get_actual_annotations(self, annotations_path):
        """Group the reference captions of an annotations JSON file by image id.

        Args:
            annotations_path: Path to a JSON file whose top-level
                ``"annotations"`` list holds entries with ``"image_id"`` and
                ``"caption"`` keys.

        Returns:
            A dict mapping each image id to the list of its captions, in file
            order.
        """
        import json

        with open(annotations_path, encoding="utf-8") as f:
            data_json = json.load(f)

        annotations = {}
        for entry in data_json["annotations"]:
            annotations.setdefault(entry["image_id"], []).append(entry["caption"])
        return annotations

    def word_list(self, word_idx_list, vocab):
        """Translate vocabulary ids into words, stopping at the end word.

        Args:
            word_idx_list: Sequence of vocabulary ids.
            vocab: Vocabulary exposing ``idx2word``, ``start_word`` and
                ``end_word``.

        Returns:
            The words preceding the first end word, with start words skipped.
        """
        words = []
        for vocab_id in word_idx_list:
            word = vocab.idx2word[vocab_id]
            if word == vocab.end_word:
                break
            if word != vocab.start_word:
                words.append(word)
        return words

    def initialize_worker(self):
        """Best-effort download of the NLTK "punkt" tokenizer data.

        A failure (missing nltk, no network, ...) is reported as a warning
        instead of being raised, so worker start-up is never aborted here.
        """
        import logging

        try:
            import nltk
            nltk.download("punkt")
        except Exception as exc:
            # Only Exception is caught so KeyboardInterrupt/SystemExit still
            # propagate; the failure is logged rather than silently dropped.
            logging.getLogger(__name__).warning(
                "Could not download NLTK 'punkt' data: %s", exc)

    def create_model_components(self, hyperparams):
        """Create the encoder/decoder models and their optimizer.

        Args:
            hyperparams: Mapping providing ``"learning_rate"``,
                ``"embed_size"`` and ``"hidden_size"``.

        Returns:
            A tuple ``(models, optimizer)`` where ``models`` is
            ``[encoder, decoder]`` and ``optimizer`` is an Adam optimizer over
            the decoder parameters plus the encoder's ``embed`` and ``bn``
            parameters.
        """
        import torch
        from coco_proc.model import EncoderCNN, DecoderRNN

        # Obtain hyperparameters
        vocab_size = len(self.train_vocab)
        learning_rate = hyperparams["learning_rate"]
        embed_size = hyperparams["embed_size"]
        hidden_size = hyperparams["hidden_size"]

        # Create the models
        encoder = EncoderCNN(embed_size)
        decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
        models = [encoder, decoder]

        # Only these parameters are handed to the optimizer; the rest of the
        # encoder is left out of the update.
        params = list(decoder.parameters()) + list(encoder.embed.parameters()) + list(encoder.bn.parameters())

        # Define the optimizer
        optimizer = torch.optim.Adam(params=params, lr=learning_rate)

        return models, optimizer

    def train(self, models, optimizer, checkpoint, dataloader, hyperparams, input_device, output_device):
        """Train the models for one pass over ``dataloader``.

        Args:
            models: ``[encoder, decoder]`` as returned by
                ``create_model_components``.
            optimizer: Optimizer to step after each batch.
            checkpoint: Callable used as ``checkpoint("load")`` (returns the
                saved state dict, or a falsy value when there is none) and
                ``checkpoint("save", state)``.
            dataloader: Iterable of batches; ``batch[0]`` holds images and
                ``batch[1]`` captions. ``dataloader.dataset`` must support
                ``len``.
            hyperparams: Mapping providing ``"batch_size"``.
            input_device: Device the images are moved to.
            output_device: Device the captions are moved to.

        Returns:
            A list with one ``{"loss", "perplexity"}`` dict per batch.
        """
        import math
        import torch
        import numpy as np
        import torch.nn as nn

        # get hyperparams
        vocab_size = len(self.train_vocab)
        batch_size = hyperparams["batch_size"]

        # Define the loss function
        criterion = nn.CrossEntropyLoss().cuda() if torch.cuda.is_available() else nn.CrossEntropyLoss()

        # get models
        encoder, decoder = models

        # Resume from the previously saved state, if any. The state lives in
        # the dict returned by checkpoint("load"), not in the callable itself.
        states = checkpoint("load")
        if states:
            encoder.load_state_dict(states["encoder"])
            decoder.load_state_dict(states["decoder"])
            optimizer.load_state_dict(states["optimizer"])

        encoder.train()
        decoder.train()

        # Expected number of batches, used only for the progress line.
        subepoch_total_step = math.ceil(len(dataloader.dataset) / batch_size)
        train_metrics = []

        for i_step, batch in enumerate(dataloader):
            images, captions = batch[0].to(input_device), batch[1].to(output_device)
            features = encoder(images)
            outputs = decoder(features, captions)
            loss = criterion(outputs.view(-1, vocab_size), captions.view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            perplexity = np.exp(loss_value)
            train_metrics.append({"loss": loss_value, "perplexity": perplexity})

            stats = "Train step [%d/%d], Loss: %.4f, Perplexity: %5.4f" \
                        % (i_step, subepoch_total_step, loss_value, perplexity)
            print("\r" + stats, end="")

        checkpoint("save", {"encoder": encoder.state_dict(),
                            "decoder": decoder.state_dict(),
                            "optimizer": optimizer.state_dict()})

        return train_metrics

    def test(self, models, checkpoint, dataloader, hyperparams, input_device, output_device):
        """Evaluate the models for one pass over ``dataloader``.

        Args:
            models: ``[encoder, decoder]`` as returned by
                ``create_model_components``.
            checkpoint: Callable used as ``checkpoint("load")``; returns the
                saved state dict, or a falsy value when there is none.
            dataloader: Iterable of batches; ``batch[0]`` holds images,
                ``batch[1]`` captions and ``batch[2]`` the row ids used to
                look up reference captions (each read via ``.item()``).
                ``dataloader.dataset`` must support ``len``.
            hyperparams: Mapping providing ``"batch_size"``.
            input_device: Device the images are moved to.
            output_device: Device the captions are moved to.

        Returns:
            A dict with ``"total_epoch_loss"`` and ``"total_bleu_4"``: the
            summed per-batch loss and per-batch mean BLEU score, each divided
            by the expected number of batches.
        """
        import math
        import nltk
        import torch
        import numpy as np
        import torch.nn as nn
        from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

        # get hyperparams
        vocab = self.train_vocab
        batch_size = hyperparams["batch_size"]

        val_annotations_path = self.miscellaneous_path + "/captions_val2017.json"
        annotations_valid = self.get_actual_annotations(val_annotations_path)

        # get models
        encoder, decoder = models

        # Define the loss function
        criterion = nn.CrossEntropyLoss().cuda() if torch.cuda.is_available() else nn.CrossEntropyLoss()

        # Restore the saved weights, if any, from the loaded state dict.
        states = checkpoint("load")
        if states:
            encoder.load_state_dict(states["encoder"])
            decoder.load_state_dict(states["decoder"])

        # Switch to validation mode
        encoder.eval()
        decoder.eval()

        # Initialize smoothing function
        smoothing = SmoothingFunction()

        # Running totals of the validation loss and BLEU score.
        total_loss = 0.0
        total_bleu_4 = 0.0
        subepoch_total_step = math.ceil(len(dataloader.dataset) / batch_size)

        with torch.no_grad():
            # Steps are numbered from 1 in the progress line.
            for step, batch in enumerate(dataloader, start=1):
                images, captions, row_ids = batch[0].to(input_device), batch[1].to(output_device), batch[2]

                # Pass the inputs through the CNN-RNN model and bring both the
                # predictions and the targets back to the CPU.
                features = encoder(images)
                outputs = decoder(features, captions).to("cpu")
                captions = captions.to("cpu")

                # Accumulate the sentence-level BLEU score over the batch.
                batch_bleu_4 = 0.0
                for i in range(len(outputs)):
                    predicted_ids = [scores.argmax().item() for scores in outputs[i]]
                    predicted_word_list = self.word_list(predicted_ids, vocab)

                    tokenized_references = [nltk.tokenize.word_tokenize(str(caption).lower())
                                            for caption in annotations_valid[row_ids[i].item()]]
                    batch_bleu_4 += sentence_bleu(tokenized_references,
                                                  predicted_word_list,
                                                  smoothing_function=smoothing.method1)
                mean_batch_bleu_4 = batch_bleu_4 / len(outputs)
                total_bleu_4 += mean_batch_bleu_4

                # Calculate the batch loss
                loss = criterion(outputs.view(-1, len(vocab)), captions.view(-1))
                loss_value = loss.item()
                total_loss += loss_value

                # Get validation statistics
                stats = "Val step [%d/%d], Loss: %.4f, Perplexity: %5.4f, Batch Bleu-4: %.4f" \
                        % (step, subepoch_total_step,
                           loss_value, np.exp(loss_value), mean_batch_bleu_4)
                print("\r" + stats, end="")

        test_metrics = {
            "total_epoch_loss": total_loss / subepoch_total_step,
            "total_bleu_4": total_bleu_4 / subepoch_total_step
        }

        return test_metrics

### <font color='blue'> Run Cerebro </font>

In [None]:
# Number of epochs handed to the fit run.
num_epochs = 2

# Hyperparameter grid: one list of candidate values per hyperparameter name.
param_grid = dict(
    learning_rate=[1e-2, 1e-3],
    embed_size=[256],
    hidden_size=[256],
    batch_size=[128],
)

In [None]:
# Create the Cerebro experiment object and one instance of each spec class
# defined in this notebook. Both spec constructors build a Vocabulary from the
# train annotations file, so this cell presumably needs that file to be present.
experiment = Experiment()
coco_etl_spec = CocoETLSpec()
coco_training_spec = CocoTrainingSpec()

In [None]:
# Run the ETL stage with the COCO ETL spec.
# NOTE(review): fraction=0.1 presumably restricts processing to 10% of the
# data — confirm against Experiment.run_etl.
experiment.run_etl(coco_etl_spec, fraction=0.1)

In [None]:
# Start the fit run with the training spec, the hyperparameter grid and the
# epoch count (presumably one model per grid combination — confirm against
# Experiment.run_fit).
experiment.run_fit(coco_training_spec, param_grid, num_epochs)