# COMS W4705 - Homework 3
## Conditioned LSTM Language Model for Image Captioning
Daniel Bauer <bauer@cs.columbia.edu>

Follow the instructions in this notebook step by step. Much of the code is provided (especially in Parts I, II, and III), but some sections are marked with **todo**. Make sure to complete all these sections. 

Specifically, you will build the following components: 

* Part I (14pts): Create encoded representations for the images in the flickr dataset using a pretrained image encoder (ResNet).
* Part II (14pts): Prepare the input caption data.
* Part III (24pts): Train an LSTM language model on the caption portion of the data and use it as a generator. 
* Part IV (24pts): Modify the LSTM model to also pass a copy of the input image in each timestep. 
* Part V (24pts): Implement beam search for the image caption generator.

Access to a GPU is required for this assignment. If you have a recent Mac, you can try using MPS. Otherwise, I recommend renting a GPU instance through a service like vast.ai or lambdalabs. Google Colab can work in a pinch, but you would have to deal with quotas and it's somewhat easy to lose unsaved work. 

### Getting Started 

There are a few required packages. 

In [None]:
import os
import PIL # Python Image Library

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.models import ResNet18_Weights

In [None]:
if torch.cuda.is_available(): 
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else: 
    DEVICE = 'cpu'
    print("You won't be able to train the RNN decoder on a CPU, unfortunately.")
print(DEVICE)

### Access to the flickr8k data

We will use the flickr8k data set, described here in more detail: 

> M. Hodosh, P. Young and J. Hockenmaier (2013) "Framing Image Description as a Ranking Task: Data, Models and Evaluation Metrics", Journal of Artificial Intelligence Research, Volume 47, pages 853-899 http://www.jair.org/papers/paper3994.html 

N.B.: Usage of this data is limited to this homework assignment. If you would like to experiment with the dataset beyond this course, I suggest that you submit your own download request here (it's free): https://forms.illinois.edu/sec/1713398


The data is available in a Google Cloud storage bucket here:
https://storage.googleapis.com/4705_sp25_hw3/hw3data.zip

In [None]:
#Download the data.
!wget https://storage.googleapis.com/4705_sp25_hw3/hw3data.zip

In [None]:
#Then unzip the data 
!unzip hw3data.zip

Alternative option if you are using Colab (though using wget, as shown above, works on Colab as well):
* The data is available on google drive. You can access the folder here:
https://drive.google.com/drive/folders/1sXWOLkmhpA1KFjVR0VjxGUtzAImIvU39?usp=sharing
* Sharing is only enabled for the lionmail domain. Please make sure you are logged into Google Drive using your Columbia UNI. I will not be able to respond to individual sharing requests from your personal account. 

* Once you have opened the folder, click on "Shared With Me", then select the hw5data folder, and press shift+z. This will open the "add to drive" menu. Add the folder to your drive. (This will not create a copy, but just an additional entry point to the shared folder). 

The following variable should point to the location where the data is located. 

In [None]:
#this is where you put the name of your data folder.
#Please make sure it's correct because it'll be used in many places later.
MY_DATA_DIR="hw3data"

## Part I: Image Encodings (14 pts)

The files `Flickr_8k.trainImages.txt`, `Flickr_8k.devImages.txt`, and `Flickr_8k.testImages.txt` contain lists of training, development, and test images, respectively. Let's load these lists. 

In [None]:
def load_image_list(filename):
    with open(filename,'r') as image_list_f: 
        return [line.strip() for line in image_list_f]    

In [None]:
FLICKR_PATH="hw3data/"

In [None]:
train_list = load_image_list(os.path.join(FLICKR_PATH, 'Flickr_8k.trainImages.txt'))
dev_list = load_image_list(os.path.join(FLICKR_PATH,'Flickr_8k.devImages.txt'))
test_list = load_image_list(os.path.join(FLICKR_PATH,'Flickr_8k.testImages.txt'))

Let's see how many images there are.

In [None]:
len(train_list), len(dev_list), len(test_list)

Each entry is an image filename.

In [None]:
dev_list[20]

The images are located in a subdirectory.  

In [None]:
IMG_PATH = os.path.join(FLICKR_PATH, "Flickr8k_Dataset")

We can use PIL to open and display the image:

In [None]:
image = PIL.Image.open(os.path.join(IMG_PATH, dev_list[20]))
image

### Preprocessing

We are going to use an off-the-shelf pre-trained image encoder, the ResNet-18 network. Here is more detail about this model (not required for this project): 

> Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun; Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, pp. 770-778 
> https://openaccess.thecvf.com/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf

The model was initially trained on an object recognition task over the ImageNet1k data. The task is to predict the correct class label for an image, from a set of 1000 possible classes.

To feed the flickr images to ResNet, we need to perform the same normalization that was applied to the training images. More details here: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html

In [None]:
from torchvision import transforms 

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

The resulting images, after preprocessing, are (3,224,224) tensors, where the first dimension represents the three color channels (R, G, B).

In [None]:
processed_image = preprocess(image)
processed_image.shape

To the ResNet18 model, the images look like this: 

In [None]:
transforms.ToPILImage()(processed_image)

### Image Encoder
Let's instantiate the ResNet18 encoder. We are going to use the pretrained weights available in torchvision.

In [None]:
img_encoder = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)

In [None]:
img_encoder.eval()

This is a prediction model, so the output is typically a softmax-activated vector representing 1000 possible object types. Because we are interested in an encoded representation of the image, we are just going to use the second-to-last layer as a source of image encodings. Each image will be encoded as a vector of size 512. 

We will use the following hack: remove the last layer, then reinstantiate a Sequential model from the remaining layers. 
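
To see what this kind of truncation does, here is a minimal sketch on a toy network. `ToyClassifier` is a hypothetical stand-in, not the real ResNet18; it only mimics the pattern of flattening inside `forward()` rather than in a submodule. Because `children()` returns only registered submodules, the truncated model skips both the final linear layer and the flatten step, so its output keeps the trailing 1x1 spatial dimensions.

```python
import torch
from torch import nn

class ToyClassifier(nn.Module):
    """Hypothetical stand-in for a ResNet-style model (not the real ResNet18)."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))   # global average pooling
        self.fc = nn.Linear(8, 10)                 # classification head

    def forward(self, x):
        x = self.pool(self.conv(x))
        x = torch.flatten(x, 1)   # done in forward(), so it is not a child module
        return self.fc(x)

toy = ToyClassifier().eval()
# children() yields the registered submodules in definition order: conv, pool, fc
toy_encoder = nn.Sequential(*list(toy.children())[:-1]).eval()

with torch.no_grad():
    batch = torch.zeros(2, 3, 16, 16)
    full_out = toy(batch)          # class scores
    enc_out = toy_encoder(batch)   # pooled features, still with 1x1 spatial dims

print(full_out.shape, enc_out.shape)
```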

In [None]:
lastremoved = list(img_encoder.children())[:-1]
img_encoder = torch.nn.Sequential(*lastremoved).to(DEVICE) # also send it to GPU memory
img_encoder.eval()

Let's try the encoder.

In [None]:
def get_image(img_name): 
    image = PIL.Image.open(os.path.join(IMG_PATH, img_name))
    return preprocess(image)

In [None]:
preprocessed_image = get_image(train_list[0])
encoded = img_encoder(preprocessed_image.unsqueeze(0).to(DEVICE)) # unsqueeze required to add batch dim (3,224,224) becomes (1,3,224,224)
encoded.shape                      

The result isn't quite what we wanted: The final representation is actually a 1x1 "image" (the first dimension is the batch size). 
We can just grab this one pixel:

In [None]:
encoded = encoded[:,:,0,0] #this is our final image encoded
encoded.shape 

**TODO:** Because we are just using the pretrained encoder, we can simply encode all the images in a preliminary step. We will store them in one big tensor (one for each dataset, train, dev, test). This will save some time when training the conditioned LSTM because we won't have to recompute the image encodings with each training epoch. We can also save the tensors to disk so that we never have to touch the bulky image data again.

Complete the following function that should take a list of image names and return a tensor of size [n_images, 512] (where each row represents one image). 

For example, `encode_images(train_list)` should return a [6000,512] tensor. 
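
One possible pattern for this kind of preliminary encoding pass is sketched below on made-up data. The helper `encode_in_batches`, the random "images", and the tiny `toy_encoder` are all assumptions for illustration; they are not part of the provided code. The sketch processes inputs in mini-batches, disables gradient tracking, and concatenates the per-batch results into one tensor with one row per input.

```python
import torch
from torch import nn

def encode_in_batches(inputs, encoder, batch_size=4):
    """Run `encoder` over `inputs` in mini-batches and stack the results row-wise."""
    chunks = []
    encoder.eval()
    with torch.no_grad():                        # a frozen encoder needs no gradients
        for start in range(0, len(inputs), batch_size):
            batch = torch.stack(inputs[start:start + batch_size])
            chunks.append(encoder(batch).cpu())  # collect results on the CPU
    return torch.cat(chunks, dim=0)

# Hypothetical stand-ins: 10 random "images" and a tiny encoder with 6 output features.
fake_images = [torch.randn(3, 8, 8) for _ in range(10)]
toy_encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 6))

encodings = encode_in_batches(fake_images, toy_encoder)
print(encodings.shape)
```

Batching is optional; encoding one image at a time also works, just more slowly.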

In [None]:
def encode_images(image_list): 
    pass #TODO.... 
           
enc_images_train = encode_images(train_list)
enc_images_train.shape

We can now save this to disk:

In [None]:
torch.save(enc_images_train, open('encoded_images_train.pt','wb'))

It's a good idea to save the resulting matrices, so we do not have to run the encoder again. 

## Part II Text (Caption) Data Preparation (14 pts)

Next, we need to load the image captions and generate training data for the language model. We will train a text-only model first.

### Reading image descriptions

**TODO**: Write the following function that reads the image descriptions from the file `filename` and returns a dictionary in the following format. Take a look at the file `Flickr8k.token.txt` for the format of the input file. 
The keys of the dictionary should be image filenames. Each value should be a list of 5 captions. Each caption should be a list of tokens.  

The captions in the file are already tokenized, so you can just split them at white spaces. You should convert each token to lower case. You should then pad each caption with a \<START\> token on the left and an \<EOS\> token on the right. 
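
As a rough sketch of the per-line processing, the helper below handles a single made-up line. It assumes each line has the layout `<image name>#<caption number><TAB><caption>`; please check this against the actual file before relying on it. The name `parse_caption_line` and the example line are hypothetical.

```python
def parse_caption_line(line):
    """Split one line into (image filename, padded lowercase token list).

    Assumes the layout '<image name>#<caption number><TAB><caption>'.
    """
    image_part, caption = line.rstrip("\n").split("\t", 1)
    image_name = image_part.split("#")[0]          # drop the '#<number>' suffix
    tokens = [token.lower() for token in caption.split()]
    return image_name, ["<START>"] + tokens + ["<EOS>"]

example_line = "example_image.jpg#0\tA black dog runs .\n"   # made-up line
name, tokens = parse_caption_line(example_line)
print(name, tokens)
```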

For example, a single caption might look like this: 
['\<START\>',
  'a',
  'child',
  'in',
  'a',
  'pink',
  'dress',
  'is',
  'climbing',
  'up',
  'a',
  'set',
  'of',
  'stairs',
  'in',
  'an',
  'entry',
  'way',
  '.',
  '\<EOS\>'],

In [None]:
def read_image_descriptions(filename):    
    image_descriptions = {}
    
    with open(filename,'r') as in_file:
        pass # todo
    
    return image_descriptions

In [None]:
os.path.join(FLICKR_PATH, "Flickr8k.token.txt")

In [None]:
descriptions = read_image_descriptions(os.path.join(FLICKR_PATH, "Flickr8k.token.txt"))

In [None]:
descriptions['1000268201_693b08cb0e.jpg']

The previous line should return 
<pre>[['<START>', 'a', 'child', 'in', 'a', 'pink', 'dress', 'is', 'climbing', 'up', 'a', 'set', 'of', 'stairs', 'in', 'an', 'entry', 'way', '.', '<EOS>'], ['<START>', 'a', 'girl', 'going', 'into', 'a', 'wooden', 'building', '.', '<EOS>'], ['<START>', 'a', 'little', 'girl', 'climbing', 'into', 'a', 'wooden', 'playhouse', '.', '<EOS>'], ['<START>', 'a', 'little', 'girl', 'climbing', 'the', 'stairs', 'to', 'her', 'playhouse', '.', '<EOS>'], ['<START>', 'a', 'little', 'girl', 'in', 'a', 'pink', 'dress', 'going', 'into', 'a', 'wooden', 'cabin', '.', '<EOS>']]</pre>

### Creating Word Indices

Next, we need to create a lookup table from the **training** data mapping words to integer indices, so we can encode input 
and output sequences using numeric representations. 

**TODO** Create the dictionaries `id_to_word` and `word_to_id`, which should map numeric ids to tokens and tokens to numeric ids, respectively.  
Hint: Create a set of tokens in the training data first, then convert the set into a list and sort it. This way if you run the code multiple times, you will always get the same dictionaries. This is similar to the word indices you created for homework 3 and 4.  

Make sure you create word indices for the three special tokens `<PAD>`, `<START>`, and `<EOS>` (end of sentence).
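
The sorted-set idea from the hint can be sketched on a toy corpus as follows. The captions and the `toy_` names are made up for illustration; the three special tokens are given the fixed ids 0, 1, and 2, and the remaining words are numbered in sorted order so that repeated runs produce the same mapping.

```python
toy_captions = [
    ["<START>", "a", "black", "dog", "<EOS>"],
    ["<START>", "a", "dog", "runs", "<EOS>"],
]

special_tokens = ["<PAD>", "<START>", "<EOS>"]     # reserved ids 0, 1, 2
words = sorted({tok for caption in toy_captions for tok in caption} - set(special_tokens))

# ids are assigned in list order: special tokens first, then the sorted words
toy_id_to_word = dict(enumerate(special_tokens + words))
toy_word_to_id = {word: idx for idx, word in toy_id_to_word.items()}

print(toy_id_to_word)
```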

In [None]:
id_to_word = {} #todo
id_to_word[0] = "<PAD>"
id_to_word[1] = "<START>"
id_to_word[2] = "<EOS>"
word_to_id = {} # todo

In [None]:
word_to_id['cat'] # should print an integer

In [None]:
id_to_word[1] # should print a token

Note that we do not need an UNK word token because we will only use the model as a generator, once trained.

## Part III Basic Decoder Model (24 pts)

For now, we will just train a model for text generation without conditioning the generator on the image input. 

We will use the LSTM implementation provided by PyTorch. The core idea here is that the recurrent layers (including LSTM) create an "unrolled" RNN. Each time-step is represented as a different position, but the weights for these positions are shared. We are going to use the constant MAX_LEN to refer to the maximum length of a sequence, which turns out to be 40 words in this data set (including `<START>` and `<EOS>`).

In [None]:
MAX_LEN = max(len(description) for image_id in train_list for description in descriptions[image_id])
MAX_LEN

In class, we discussed LSTM generators as transducers that map each word in the input sequence to the next word. 
<img src="http://www.cs.columbia.edu/~bauer/4705/lstm1.png" width="480px">

To train the model, we will convert each description into an input output pair as follows. For example, consider the sequence 

`['<START>', 'a', 'black', 'dog', '<EOS>']`

We would train the model using the following input/output pair (note that both sequences are padded on the right up to MAX_LEN). That is, the output is simply the input shifted left (with an extra `<PAD>` on the right).

output| [`a`,`black`,`dog`,`<EOS>`,`<PAD>`,`<PAD>`,...]   |
------|---------------------------------------------------|
input | [`<START>`,`a`,`black`,`dog`,`<EOS>`,`<PAD>`,...] |
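
The shift-and-pad construction can be sketched in plain Python. The helper name `make_training_pair` and the token ids below are hypothetical; the real ids depend on your `word_to_id` dictionary.

```python
def make_training_pair(token_ids, max_len, pad_id=0):
    """Return (input, output) id lists of length max_len; output is input shifted left."""
    padded = token_ids + [pad_id] * (max_len + 1 - len(token_ids))
    return padded[:max_len], padded[1:max_len + 1]

# Hypothetical ids: <PAD>=0, <START>=1, <EOS>=2, a=3, black=4, dog=5
caption_ids = [1, 3, 4, 5, 2]
input_ids, output_ids = make_training_pair(caption_ids, max_len=8)
print(input_ids)    # [1, 3, 4, 5, 2, 0, 0, 0]
print(output_ids)   # [3, 4, 5, 2, 0, 0, 0, 0]
```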

Here is the language model in PyTorch. We will choose input embeddings of dimensionality 512 (for simplicity, we are not initializing these with pre-trained embeddings here). We will also use 512 for the hidden state vector and the output. 

In [None]:
from torch import nn 

vocab_size = len(word_to_id)+1
class GeneratorModel(nn.Module): 
    
    def __init__(self): 
        super(GeneratorModel, self).__init__()    
        self.embedding = nn.Embedding(vocab_size, 512) 
        self.lstm = nn.LSTM(512, 512, num_layers = 1, bidirectional=False, batch_first=True)
        self.output = nn.Linear(512,vocab_size)
        
    def forward(self, input_seq): 
        hidden = self.lstm(self.embedding(input_seq))
        out = self.output(hidden[0])
        return out

The input sequence is an integer tensor of size `[batch_size, MAX_LEN]`. Each row is a vector of size MAX_LEN in which each entry is an integer representing a word (according to the `word_to_id` dictionary). If the input sequence is shorter than MAX_LEN, the remaining entries should be padded with `<PAD>`.

For each position of each input example, the model returns a score for every possible output word. The model output is a tensor of size `[batch_size, MAX_LEN, vocab_size]`, where `vocab_size` is the size of the output layer, set to `len(word_to_id)+1` in the code above.

### Creating a Dataset for the text training data

**TODO**: Write a Dataset class for the text training data. The `__getitem__` method should return an (input_encoding, output_encoding) pair for a single item. Both input_encoding and output_encoding should be tensors of size `[MAX_LEN]`, encoding the padded input/output sequence as illustrated above. 

I recommend first reading in all captions in the `__init__` method and storing them in a list. Above, we used the `read_image_descriptions` function to load the image descriptions into a dictionary. Iterate through the images in img_list, then access the corresponding captions in the `descriptions` dictionary. 

You can just refer back to the variables you have defined above, including `descriptions`, `train_list`, `vocab_size`, etc. 


In [None]:
MAX_LEN = 40

class CaptionDataset(Dataset):
    
    def __init__(self, img_list):
                
        self.data = []

        pass # TODO complete this method
        
    def __len__(self):
        return len(self.data)
        
    def __getitem__(self,k):
        
        #TODO COMPLETE THIS METHOD

        input_enc = None #replace
        output_enc = None #replace
        return input_enc, output_enc

Let's instantiate the caption dataset and get the first item. You want to see something like this: 

for the input: 
<pre>
tensor([   1,   74,  805, 2312, 4015, 6488,  170,   74, 8686, 2312, 3922, 7922,
        7125,   17,    2,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0])
</pre>
for the output:
<pre>
    tensor([  74,  805, 2312, 4015, 6488,  170,   74, 8686, 2312, 3922, 7922, 7125,
          17,    2,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0])
</pre>

In [None]:
data = CaptionDataset(train_list)

In [None]:
i, o = data[0]
i

In [None]:
o

Let's try the model:

In [None]:
model = GeneratorModel().to(DEVICE)

In [None]:
model(i.to(DEVICE)).shape   # should return a [40, vocab_size]  tensor.

### Training the Model

The training function is identical to what you saw in homework 3 and 4.
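
Two details of the loss call are worth a quick illustration: `CrossEntropyLoss` expects the class dimension in second position, which is why the logits are transposed, and `ignore_index=0` excludes `<PAD>` positions from the average. The sketch below uses made-up shapes and random logits and compares the built-in loss against a manual computation over the non-PAD positions.

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, toy_vocab = 2, 4, 5
logits = torch.randn(batch_size, seq_len, toy_vocab)      # model output layout
targets = torch.tensor([[3, 4, 2, 0],
                        [1, 2, 0, 0]])                    # 0 = <PAD>

loss_fn = CrossEntropyLoss(ignore_index=0, reduction='mean')
# CrossEntropyLoss expects (batch, classes, seq_len), hence the transpose
loss = loss_fn(logits.transpose(2, 1), targets)

# The same value computed by hand over the non-PAD positions only
log_probs = torch.log_softmax(logits, dim=2)
mask = targets != 0
picked = log_probs.gather(2, targets.unsqueeze(2)).squeeze(2)
manual = -picked[mask].mean()

print(loss.item(), manual.item())
```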

In [None]:
from torch.nn import CrossEntropyLoss
loss_function = CrossEntropyLoss(ignore_index = 0, reduction='mean')

LEARNING_RATE = 1e-03
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

loader = DataLoader(data, batch_size = 16, shuffle = True)

def train():
    """
    Train the model for one epoch.
    """
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    total_correct, total_predictions = 0, 0
    tr_preds, tr_labels = [], []
    # put model in training mode
    model.train()

    for idx, batch in enumerate(loader):

        inputs,targets = batch   
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        # Run the forward pass of the model
        logits = model(inputs)
        loss = loss_function(logits.transpose(2,1), targets)
        tr_loss += loss.item()
        #print("Batch loss: ", loss.item()) # can comment out if too verbose.
        nb_tr_steps += 1
        nb_tr_examples += targets.size(0)

        # Calculate accuracy
        predictions = torch.argmax(logits, dim=2)  # Predicted token labels
        not_pads = targets != 0  # Mask for non-PAD tokens
        correct = torch.sum((predictions == targets) & not_pads)
        total_correct += correct.item()
        total_predictions += not_pads.sum().item()

        if idx % 100==0:
            #torch.cuda.empty_cache() # can help if you run into memory issues
            curr_avg_loss = tr_loss/nb_tr_steps
            print(f"Current average loss: {curr_avg_loss}")

        # Run the backward pass to update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute accuracy for this batch
        # matching = torch.sum(torch.argmax(logits,dim=2) == targets)
        # predictions = torch.sum(torch.where(targets==-100,0,1))

    epoch_loss = tr_loss / nb_tr_steps
    epoch_accuracy = total_correct / total_predictions if total_predictions != 0 else 0  # Avoid division by zero
    print(f"Training loss epoch: {epoch_loss}")
    print(f"Average accuracy epoch: {epoch_accuracy:.2f}")

Run the training until the accuracy reaches about 0.5 (this would be high for a language model on open-domain text, but the image caption dataset is comparatively small and closed-domain). This will take about 5 epochs.

In [None]:
train()

### Greedy Decoder

**TODO** Next, you will write a decoder. The decoder should start with the sequence `["<START>", "<PAD>","<PAD>"...]` and use the model to predict the most likely word in the next position. Append the word to the input sequence and then continue until `"<EOS>"` is predicted or the sequence reaches `MAX_LEN` words. 
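
The control flow of greedy decoding can be sketched without a neural network. In the sketch below, `TOY_SCORES` and `toy_next_word_scores` are made-up stand-ins for a call to the trained model; your decoder will instead encode the partial sequence as a padded tensor and read the model output at the current position.

```python
# Hypothetical bigram scores standing in for the trained model's output
TOY_SCORES = {
    "<START>": {"a": 0.9, "the": 0.1},
    "a": {"dog": 0.6, "cat": 0.4},
    "the": {"dog": 0.5, "cat": 0.5},
    "dog": {"runs": 0.7, "<EOS>": 0.3},
    "cat": {"<EOS>": 1.0},
    "runs": {"<EOS>": 1.0},
}

def toy_next_word_scores(sequence):
    """Stand-in for a model call: scores for the word following `sequence`."""
    return TOY_SCORES[sequence[-1]]

def greedy_decode(next_word_scores, max_len):
    sequence = ["<START>"]
    # stop when <EOS> was produced or the length limit is reached
    while len(sequence) < max_len and sequence[-1] != "<EOS>":
        scores = next_word_scores(sequence)
        sequence.append(max(scores, key=scores.get))   # most likely next word
    return sequence

print(greedy_decode(toy_next_word_scores, max_len=10))
```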

In [None]:
def decoder():
    pass # TODO COMPLETE THIS METHOD

In [None]:
decoder()

This will return something like:
['a',
 'man',
 'in',
 'a',
 'white',
 'shirt',
 'and',
 'a',
 'woman',
 'in',
 'a',
 'white',
 'dress',
 'walks',
 'by',
 'a',
 'small',
 'white',
 'building',
 '.',
 '<EOS>']

This simple decoder will of course always predict the same sequence (and it's not necessarily a good one). 

**TODO:** Modify the decoder as follows. Instead of choosing the most likely word in each step, sample the next word from the distribution (i.e. the softmax activated output) returned by the model. Make sure to apply torch.softmax() to convert the output activations into a distribution. 

To sample from the distribution, I recommend you take a look at [np.random.choice](https://numpy.org/doc/stable/reference/random/generated/numpy.random.choice.html), which takes the distribution as a parameter p.
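
Here is a small sketch of the sampling step on made-up scores. It computes the softmax in NumPy purely to stay self-contained (in the notebook you would apply `torch.softmax` to the model output); the helper name `sample_index` and the logits are hypothetical. Converting to float64 and renormalizing is a precaution, since `np.random.choice` checks that `p` sums to 1 and lower-precision probabilities can sometimes fail that check.

```python
import numpy as np

def sample_index(logits, rng=np.random):
    """Sample one index from the softmax distribution over `logits`."""
    logits = np.asarray(logits, dtype=np.float64)
    probs = np.exp(logits - logits.max())    # subtract the max for numerical stability
    probs /= probs.sum()                     # normalize so the entries sum to 1
    return int(rng.choice(len(probs), p=probs))

np.random.seed(0)                            # only to make this sketch repeatable
toy_logits = [2.0, 1.0, 0.1, -1.0]           # hypothetical scores for a 4-word vocabulary
samples = [sample_index(toy_logits) for _ in range(1000)]
counts = np.bincount(samples, minlength=4)
print(counts / 1000.0)                       # empirical frequencies, roughly the softmax
```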

In [None]:
import numpy as np 

def sample_decoder():
    pass # TODO COMPLETE THIS METHOD

for i in range(5):
    print(sample_decoder())

Some example outputs (it's stochastic, so your results will vary):

<pre>
['<START>', 'people', 'on', 'rocky', 'ground', 'swinging', 'basketball', '<EOS>']
['<START>', 'the', 'two', 'hikers', 'take', 'a', 'tandem', 'leap', 'while', 'another', 'is', 'involving', 'watching', '.', '<EOS>']
['<START>', 'a', 'man', 'attached', 'to', 'a', 'bicycle', 'rides', 'a', 'motorcycle', '.', '<EOS>']
['<START>', 'a', 'surfer', 'is', 'riding', 'a', 'wave', 'in', 'the', 'ocean', '.', '<EOS>']
['<START>', 'a', 'child', 'plays', 'in', 'a', 'round', 'fountain', '.', '<EOS>']
</pre>

You should now be able to see some interesting output that looks a lot like flickr8k image captions -- only that the captions are generated randomly without any image input. 

## Part IV - Conditioning on the Image (24 pts)

We will now extend the model to condition the next word not only on the partial sequence, but also on the encoded image. 

We will concatenate the 512-dimensional image representation to each 512-dimensional token embedding. The LSTM will therefore see input representations of size 1024.

**TODO**: Write a new Dataset class for the combined image captioning data set. Each call to `__getitem__` should return a triple (image_encoding, input_encoding, output_encoding) for a single item. Both input_encoding and output_encoding should be tensors of size [MAX_LEN], encoding the padded input/output sequence as illustrated above. The image_encoding is the size [512] tensor we pre-computed in Part I.

Note: One tricky issue here is that each image corresponds to 5 captions, so you have to find the correct image for each caption. You can create a mapping from image names to row indices in the image encoding tensor. This way you will be able to find each image by its name. 
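
The bookkeeping can be sketched with plain Python structures. The image names and captions below are made up; the point is only that each caption is stored together with the row index of its image, so several captions can share one encoding row.

```python
image_names = ["img_a.jpg", "img_b.jpg", "img_c.jpg"]   # hypothetical names
captions = {
    "img_a.jpg": [["<START>", "a", "dog", "<EOS>"], ["<START>", "dog", "<EOS>"]],
    "img_b.jpg": [["<START>", "a", "cat", "<EOS>"]],
    "img_c.jpg": [["<START>", "two", "cats", "<EOS>"]],
}

# Row k of the encoding tensor belongs to image_names[k]
name_to_row = {name: row for row, name in enumerate(image_names)}

# One entry per caption; each remembers which encoding row it belongs to
items = [(name_to_row[name], caption)
         for name in image_names
         for caption in captions[name]]

for row, caption in items:
    print(row, caption)
```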

In [None]:
MAX_LEN = 40

class CaptionAndImage(Dataset):
    
    def __init__(self, img_list):

        self.img_data = torch.load(open("encoded_images_train.pt",'rb')) # suggested 
        self.img_name_to_id = dict([(i,j) for (j,i) in enumerate(img_list)])

        self.data = []
        # TODO COMPLETE THIS METHOD
                
    def __len__(self):
        return len(self.data)
        
    def __getitem__(self,k):
        # TODO COMPLETE THIS METHOD
        img_data = None #replace
        input_enc = None #replace
        output_enc = None #replace
        
        return img_data, input_enc, output_enc

In [None]:
joint_data = CaptionAndImage(train_list)
img, i, o = joint_data[0]
img.shape # should return torch.Size([512])

In [None]:
i.shape # should return torch.Size([40])

In [None]:
o.shape # should return torch.Size([40])

**TODO: Updating the model**
Update the language model code above to include a copy of the image for each position. 
The forward function of the new model should take two inputs: 
    
   1. a `(batch_size, 512)` tensor of image encodings. 
   2. a `(batch_size, MAX_LEN)` tensor of partial input sequences. 
    
And one output as before: a `(batch_size, MAX_LEN, vocab_size)` tensor of predicted word distributions.   

The LSTM will take input dimension 1024 instead of 512 (because we are concatenating the 512-dim image encoding). 

In the forward function, take the image and the embedded input sequence (i.e. AFTER the embedding was applied), and concatenate the image to each input. This requires some tensor manipulation. I recommend taking a look at [torch.Tensor.expand](https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html) and [torch.cat](https://pytorch.org/docs/stable/generated/torch.cat.html).
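
The tensor manipulation can be tried in isolation on random data. In this sketch, `embedded` is a random stand-in for the output of the embedding layer and the sequence length is shortened to 5; only the shapes matter.

```python
import torch

batch_size, seq_len, dim = 2, 5, 512
img = torch.randn(batch_size, dim)                 # one encoding per image
embedded = torch.randn(batch_size, seq_len, dim)   # stand-in for embedded tokens

# (batch, 512) -> (batch, 1, 512) -> (batch, seq_len, 512); expand does not copy data
img_per_step = img.unsqueeze(1).expand(-1, seq_len, -1)

# Concatenate along the feature dimension: (batch, seq_len, 1024)
lstm_input = torch.cat([embedded, img_per_step], dim=2)
print(lstm_input.shape)
```

Every time step now carries the same image vector in its last 512 features, next to the token embedding in its first 512.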



In [None]:
vocab_size = len(word_to_id)+1

class CaptionGeneratorModel(nn.Module): 
    
    def __init__(self): 
        super(CaptionGeneratorModel, self).__init__()    
        # TODO COMPLETE THIS METHOD
        
    def forward(self, img, input_seq): 

        # TODO COMPLETE THIS METHOD
        out = None # replace
        
        return out

Let's try this new model on one item: 

In [None]:
model = CaptionGeneratorModel().to(DEVICE)

In [None]:
item = joint_data[0]
img, input_seq, output_seq = item

In [None]:
logits = model(img.unsqueeze(0).to(DEVICE), input_seq.unsqueeze(0).to(DEVICE))

logits.shape # should return (1,40,8922) = (batch_size, MAX_LEN, vocab_size)

The training function is, again, mostly unchanged. Keep training until the accuracy exceeds 0.5.

In [None]:
from torch.nn import CrossEntropyLoss
loss_function = CrossEntropyLoss(ignore_index = 0, reduction='mean')

LEARNING_RATE = 1e-03
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

loader = DataLoader(joint_data, batch_size = 16, shuffle = True)

def train():
    """
    Train the model for one epoch.
    """
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    total_correct, total_predictions = 0, 0
    tr_preds, tr_labels = [], []
    # put model in training mode
    model.train()

    for idx, batch in enumerate(loader):
        
        img, inputs,targets = batch  
        img = img.to(DEVICE)
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        
        # Run the forward pass of the model
        logits = model(img, inputs)
        loss = loss_function(logits.transpose(2,1), targets)
        tr_loss += loss.item()
        #print("Batch loss: ", loss.item()) # can comment out if too verbose.
        nb_tr_steps += 1
        nb_tr_examples += targets.size(0)

        # Calculate accuracy
        predictions = torch.argmax(logits, dim=2)  # Predicted token labels
        not_pads = targets != 0  # Mask for non-PAD tokens
        correct = torch.sum((predictions == targets) & not_pads)
        total_correct += correct.item()
        total_predictions += not_pads.sum().item()

        if idx % 100==0:
            #torch.cuda.empty_cache() # can help if you run into memory issues
            curr_avg_loss = tr_loss/nb_tr_steps
            print(f"Current average loss: {curr_avg_loss}")

        # Run the backward pass to update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute accuracy for this batch
        # matching = torch.sum(torch.argmax(logits,dim=2) == targets)
        # predictions = torch.sum(torch.where(targets==-100,0,1))

    epoch_loss = tr_loss / nb_tr_steps
    epoch_accuracy = total_correct / total_predictions if total_predictions != 0 else 0  # Avoid division by zero
    print(f"Training loss epoch: {epoch_loss}")
    print(f"Average accuracy epoch: {epoch_accuracy:.2f}")

In [None]:
train()

**TODO: Testing the model**: 
Rewrite the greedy decoder from above to take an encoded image representation as input.

In [None]:
def greedy_decoder(img):
    #TODO: Complete this method
    
    result = None  # replace
    return result

Now we can load one of the dev images, pass it through the preprocessor and the image encoder, and then into the decoder!

In [None]:
raw_img = PIL.Image.open(os.path.join(IMG_PATH, dev_list[199]))
preprocessed_img = preprocess(raw_img).to(DEVICE)
encoded_img = img_encoder(preprocessed_img.unsqueeze(0)).reshape((512))
caption = greedy_decoder(encoded_img)
print(caption)
raw_img

The result should look pretty good for most images, but the model is prone to hallucinations. 

## Part V - Beam Search Decoder (24 pts)

**TODO** Modify the simple greedy decoder for the caption generator to use beam search. 
Instead of always selecting the most probable word, use a *beam*, which contains the n highest-scoring sequences so far and their total probability (i.e. the product of all word probabilities). I recommend that you use a list of `(probability, sequence)` tuples. After each time-step, prune the list to include only the n most probable sequences. 

Then, for each sequence, compute the n most likely successor words. Append the word to produce n new sequences and compute their score. This way, you create a new list of n*n candidates. 

Prune this list to the best n as before and continue until `MAX_LEN` words have been generated. 

Note that you cannot use the occurrence of the `"<EOS>"` tag to terminate generation, because the tag may occur in different positions for different entries in the beam. 

Once `MAX_LEN` has been reached, return the most likely sequence out of the current n. 
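
The procedure can be sketched on a toy next-word table. Several things here are assumptions made for illustration: `TOY_PROBS` and `toy_next_word_probs` stand in for the trained, image-conditioned model; scores are sums of log-probabilities rather than products of probabilities (this ranks sequences identically but avoids numerical underflow); and sequences that already end in `<EOS>` are carried over unchanged, which is one possible way to handle finished entries.

```python
import math

# Hypothetical next-word probabilities standing in for the trained model
TOY_PROBS = {
    "<START>": {"a": 0.6, "the": 0.4},
    "a": {"dog": 0.55, "cat": 0.45},
    "the": {"dog": 0.9, "cat": 0.1},
    "dog": {"<EOS>": 1.0},
    "cat": {"<EOS>": 1.0},
}

def toy_next_word_probs(sequence):
    return TOY_PROBS[sequence[-1]]

def beam_search(next_word_probs, n, max_len):
    beam = [(0.0, ["<START>"])]                      # (log-probability, sequence)
    for _ in range(max_len - 1):
        candidates = []
        for score, seq in beam:
            if seq[-1] == "<EOS>":                   # finished: carry over unchanged
                candidates.append((score, seq))
                continue
            probs = next_word_probs(seq)
            best = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:n]
            for word, p in best:
                candidates.append((score + math.log(p), seq + [word]))
        # prune to the n highest-scoring candidates
        beam = sorted(candidates, key=lambda c: c[0], reverse=True)[:n]
    return beam[0][1]

print(beam_search(toy_next_word_probs, n=1, max_len=6))   # same as greedy decoding
print(beam_search(toy_next_word_probs, n=2, max_len=6))   # finds a more probable sequence
```

With n=1 the search commits to "a" (probability 0.6) and ends with total probability 0.33, while n=2 keeps "the" alive and finds "the dog" with total probability 0.36.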

In [None]:
def img_beam_decoder(n, img):
   
    pass # TODO: Complete this method

**TODO** Finally, before you submit this assignment, please show 3 development images, each with 1) their greedy output, 2) beam search at n=3, and 3) beam search at n=5. 