Problem with this approach: The CLIP model was finetuned like a multiclass classification problem. The model presumes that each class has 1 representative in 1 batch. 

In [1]:
#Import packages
import os
import clip
import torch
import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import json
import cv2
from torchvision.transforms import ToTensor
import pandas as pd
from PIL import Image
import torch.nn as nn
import torch.optim

  Referenced from: <85A36C65-3F71-3C3B-B529-961AE17DBE73> /Users/szaboreka/anaconda3/lib/python3.11/site-packages/torchvision/image.so
  warn(


In [2]:
# Define device
if torch.cuda.is_available():
    device = torch.device("cuda") # use CUDA device
#elif torch.backends.mps.is_available():
#    device = torch.device("mps") # use MacOS GPU device (e.g., for M2 chips)
else:
    device = torch.device("cpu") # use CPU device
device

device(type='cpu')

In [3]:
#Load CLIP model - ViT B32
model, preprocess = clip.load('ViT-B/16', device, jit=False)

In [4]:
#Function to create a square-shaped image from the video (similar to 1 long image)
#To do: what if the video has more frames than 36?
def preprocess_video_to_image_grid_version(video_path, num_rows=6, num_cols=6):
    #Open the video file
    video = cv2.VideoCapture(video_path)
    #Create list for extracted frames
    frames = []
    #Handle if video can't be opened
    if not video.isOpened():
        print("Error: Could not open video file")
    else:
        while True:
            is_read, frame = video.read()
            if not is_read:
                break
            frames.append(frame)
        video.release()
    
    # Create  and store rows in the grids
    rows_list = []
    for i in range(num_rows):
        #create rows from the frames using indexes -- for example, if i=0, then between the 0th and 6th frame
        row = np.concatenate(frames[i * num_cols: (i + 1) * num_cols], axis=1)
        rows_list.append(row)
    
    # Concatenate grid vertically to create a single square-shaped image from the smoke video
    concatenated_frames = np.concatenate(rows_list, axis=0)
    return concatenated_frames

In [5]:
#Define Torch Dataset class
class ImageTitleDataset(Dataset):
    def __init__(self, list_video_path, list_labels, class_names):
        #Initalize image paths and corresponding texts
        self.video_path = list_video_path
        #Initialize labels (0 or 1)
        self.labels = list_labels
        #Initialize class names
        self.class_names = class_names
        #Transform to tensor
        #self.transforms = ToTensor()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        #tranform videos into images and preprocess with clip's function
        image = preprocess_video_to_image_grid_version(self.video_path[idx])
        image = Image.fromarray(image)
        image = preprocess(image)
        #get the corresponding class names and tokenize
        true_label = self.labels[idx]
        label = self.class_names[true_label]
        label = clip.tokenize(label, context_length=77, truncate=True)
        return image, label, true_label

In [6]:
#Define training data
# Load the JSON metadata
with open('data/datasets/experimental_ijmond_dataset.json', 'r') as f:
    data = json.load(f)
# Convert the dataset to a Pandas DataFrame
train_data = pd.DataFrame(data)
# Prepare the list of video file paths and labels
list_video_path = [os.path.join("data/ijmond_videos/", f"{fn}.mp4") for fn in train_data['file_name']]
#list_labels = dataset['label'].tolist()
list_labels = [int(label) for label in train_data['label']]
#Define class names in a list - it needs prompt engineering
class_names = ["a sequental photo of an industrial plant with clear sky above chimney, created from a video", "a sequental photo of an industrial plant emiting smoke from chimney, created from a video"]

In [25]:
#try labels and class names
print(class_names[1])
print(list_labels)
print(class_names[list_labels[0]])

a sequental photo of an industrial plant emiting smoke from chimney, created from a video
[1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0]
a sequental photo of an industrial plant emiting smoke from chimney, created from a video


In [7]:
# Create dataset and data loader for training
dataset = ImageTitleDataset(list_video_path, list_labels, class_names)
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [27]:
# Function to convert model's parameters to FP32 format
#This is done so that our model loads in the provided memory
def convert_models_to_fp32(model): 
    for p in model.parameters(): 
        p.data = p.data.float() 
        p.grad.data = p.grad.data.float() 

# Check if the device is set to CPU
if device == "cpu":
  model.float()

# Prepare the optimizer - weight from other user (https://www.labellerr.com/blog/fine-tuning-clip-on-custom-dataset/)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2) # the lr is smaller, more safe for fine tuning to new dataset

# Specify the loss functions - for images and for texts
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

In [28]:
# Model training
num_epochs = 5
for epoch in range(num_epochs):
  #model.train(True)
  pbar = tqdm(train_dataloader, total=len(train_dataloader))
  for batch in pbar:
      # Zero out gradients for the optimizer (Adam)
      optimizer.zero_grad()

      # Extract images and texts from the batch
      images,texts, true_label = batch 
      print('Texts: ', texts) #The texts in the batch are tokenized class names

      # Print the current device (CPU or GPU)
      print("Used device: ", device)

      # Move images and texts to the specified device (CPU or GPU)
      images= images.to(device)
      texts = texts.to(device)

      #Squeeze texts tensor to match the required size
      texts = texts.squeeze(dim = 1)
      #print("Shape of input tensor before forward pass: ", texts.shape)
      #images = torch.stack([img for img in images],dim=0)

      # Forward pass
      logits_per_image, logits_per_text = model(images, texts)
      print('Logits_per_text after forward passing: ', logits_per_text)

      # Compute loss
      ground_truth = torch.tensor(true_label, dtype=torch.long, device=device)
      #ground_truth = torch.tensor(texts[batch], dtype=torch.long, device=device)
      #ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
      print('Ground truth: ', ground_truth)

      #Transform logits to float to match required dtype
      logits_per_image = logits_per_image.float()
      logits_per_text = logits_per_text.float()

      total_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2

      # Backward pass
      total_loss.backward()
      if device == "cpu":
         optimizer.step()
      else : 
        # Convert model's parameters to FP32 format, update, and convert back
        convert_models_to_fp32(model)
        optimizer.step()
        clip.model.convert_weights(model)
      # Update the progress bar with the current epoch and loss
      pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss.item():.4f}")
  

  0%|          | 0/7 [00:00<?, ?it/s]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

  ground_truth = torch.tensor(true_label, dtype=torch.long, device=device)
Epoch 0/5, Loss: 2.1630:  14%|█▍        | 1/7 [00:04<00:29,  4.93s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 0/5, Loss: 8.4932:  29%|██▊       | 2/7 [05:04<14:51, 178.26s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 0/5, Loss: 1.1301:  43%|████▎     | 3/7 [10:07<15:40, 235.12s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 0/5, Loss: 2.3338:  57%|█████▋    | 4/7 [15:08<13:04, 261.37s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 0/5, Loss: 1.5497:  71%|███████▏  | 5/7 [20:06<09:09, 274.50s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 0/5, Loss: 1.5761:  86%|████████▌ | 6/7 [25:11<04:44, 284.85s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 0/5, Loss: 0.7273: 100%|██████████| 7/7 [27:41<00:00, 237.35s/it]
  0%|          | 0/7 [00:00<?, ?it/s]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 1/5, Loss: 1.8852:  14%|█▍        | 1/7 [05:01<30:09, 301.62s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 1/5, Loss: 1.3571:  29%|██▊       | 2/7 [10:10<25:28, 305.65s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 1/5, Loss: 1.3311:  43%|████▎     | 3/7 [15:13<20:17, 304.43s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 1/5, Loss: 1.6778:  57%|█████▋    | 4/7 [20:12<15:07, 302.44s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 1/5, Loss: 1.4892:  71%|███████▏  | 5/7 [55:06<31:37, 948.60s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 1/5, Loss: 1.5020:  86%|████████▌ | 6/7 [2:36:40<44:58, 2698.00s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 1/5, Loss: 0.7422: 100%|██████████| 7/7 [2:39:07<00:00, 1363.96s/it]
  0%|          | 0/7 [00:00<?, ?it/s]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 2/5, Loss: 1.4896:  14%|█▍        | 1/7 [04:48<28:50, 288.40s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 2/5, Loss: 1.3878:  29%|██▊       | 2/7 [09:37<24:03, 288.60s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 2/5, Loss: 1.4284:  43%|████▎     | 3/7 [14:26<19:15, 288.75s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 2/5, Loss: 1.3577:  57%|█████▋    | 4/7 [19:17<14:29, 289.94s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 2/5, Loss: 1.3691:  71%|███████▏  | 5/7 [24:10<09:41, 290.95s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 2/5, Loss: 1.3591:  86%|████████▌ | 6/7 [1:49:42<32:16, 1936.94s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 2/5, Loss: 0.6574: 100%|██████████| 7/7 [1:52:16<00:00, 962.34s/it] 
  0%|          | 0/7 [00:00<?, ?it/s]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 3/5, Loss: 1.3863:  14%|█▍        | 1/7 [05:02<30:16, 302.82s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 3/5, Loss: 1.3677:  29%|██▊       | 2/7 [10:06<25:18, 303.61s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 3/5, Loss: 1.4952:  43%|████▎     | 3/7 [15:00<19:55, 298.80s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 3/5, Loss: 1.4718:  57%|█████▋    | 4/7 [19:59<14:57, 299.18s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 3/5, Loss: 1.3475:  71%|███████▏  | 5/7 [25:03<10:01, 300.85s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 3/5, Loss: 1.4079:  86%|████████▌ | 6/7 [30:06<05:01, 301.58s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 3/5, Loss: 0.6921: 100%|██████████| 7/7 [32:31<00:00, 278.81s/it]
  0%|          | 0/7 [00:00<?, ?it/s]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 4/5, Loss: 1.4220:  14%|█▍        | 1/7 [04:51<29:08, 291.43s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 4/5, Loss: 1.3372:  29%|██▊       | 2/7 [09:43<24:19, 291.89s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 4/5, Loss: 1.3532:  43%|████▎     | 3/7 [14:32<19:22, 290.68s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 4/5, Loss: 1.5983:  57%|█████▋    | 4/7 [19:44<14:57, 299.08s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   908,
           1257,  6664,   633, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 4/5, Loss: 1.3878:  71%|███████▏  | 5/7 [25:10<10:17, 308.53s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 4/5, Loss: 1.3624:  86%|████████▌ | 6/7 [58:48<14:50, 890.03s/it]

Texts:  tensor([[[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  5491,  6168,  1125,   539,   550,  7520,  3912,   593,
           3143,  2390,  4348, 26821,   267,  4080,   633,   320,  1455, 49407,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0

Epoch 4/5, Loss: 0.7380: 100%|██████████| 7/7 [1:01:22<00:00, 526.10s/it]


Code to examine the dataset object:

In [30]:
#Inspect a few examples in dataset

# Create dataset
dataset = ImageTitleDataset(list_video_path, list_labels)
print("Dataset Length:", len(dataset))

# Inspect 3 samples
for i in range(3):
    image, label = dataset[i]
    print("Sample:", i)
    print("Image Shape:", image.shape)
    print("Label:", label)
    

Dataset Length: 26
Sample: 0
Image Shape: torch.Size([3, 224, 224])
Label: tensor([[49406,   320,  1125,   539,  7520,  5829,   908,  1257,  6664,   633,
         26821, 49407,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]], dtype=torch.int32)
Sample: 1
Image Shape: torch.Size([3, 224, 224])
Label: tensor([[49406,   320,  1125,   539,  7520,  5829,   593,   871,  6664,  4348,
         26821, 49407,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,


In [15]:
#Inspect Batch sizes
# Create DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Iterate over a few batches
for images, labels in dataloader:
    print("Batch Images Shape:", images.shape)
    print("Batch Labels:", labels)
    break  # Stop after first batch

# (batch_size, channel, time, height, width)

Batch Images Shape: torch.Size([4, 3, 224, 224])
Batch Labels: tensor([[[49406,   320,  1125,   539,  7520,  5829,   908,  1257,  6664,   633,
          26821, 49407,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]],

        [[49406,   320,  1125,   539,  7520,  5829,   593,   871,  6664,  4348,
          26821, 49407,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,   

Function to create 1 long image

In [4]:
#Function to create one long image from video frames
def preprocess_video_to_image(video_path):
    # Open the video file
    video = cv2.VideoCapture(video_path)
    frames = []
    #Handle if video can't be opened
    if not video.isOpened():
        print("Video file couldn't be opened")
    #If yes, read all video frames until the end of the video and append every frame to the frames list
    else:
        while True:
            ret, frame = video.read()
            if not ret:
                break
            frames.append(frame)
        #Release video
        video.release()
    #Concetanate the frames in the list together
    concatenated_frames = np.concatenate(frames, axis=1)
    return concatenated_frames

Code to save model and open the saved model

In [None]:
#code to save the trained model
torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': total_loss,
        }, f"model_checkpoint/model_clip_1.pt") #just change to your preferred folder/filename

In [None]:
#Code to load the saved model :
model, preprocess = clip.load("ViT-B/32",device=device,jit=False) #Must set jit=False for training
checkpoint = torch.load("model_checkpoint/model_clip_1.pt")

# Use these 3 lines if you use default model setting(not training setting) of the clip. For example, if you set context_length to 100 since your string is very long during training, then assign 100 to checkpoint['model_state_dict']["context_length"] 
checkpoint['model_state_dict']["input_resolution"] = model.input_resolution #default is 224
checkpoint['model_state_dict']["context_length"] = model.context_length # default is 77
checkpoint['model_state_dict']["vocab_size"] = model.vocab_size 

model.load_state_dict(checkpoint['model_state_dict'])