In [None]:
rm -rf /content/dnnls_final_project


In [None]:
# Move the kernel's working directory to /content before cloning the repository.
%cd /content


In [None]:

# Fetch the project code and move into the repository root; the import cell
# below imports from its `src` package.
!git clone https://github.com/scrannah/dnnls_final_project.git
%cd /content/dnnls_final_project


In [None]:
# Mount Google Drive at /content/gdrive. The checkpoint helpers used later are
# named *_to_drive / *_from_drive, so they presumably read and write under this
# mount point (verify against src/utils/checkpoints.py).
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# @title Importing the necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F

from bs4 import BeautifulSoup
import re
import matplotlib.pyplot as plt
import numpy as np

from google.colab import drive
import os

from datasets import load_dataset

from datasets.fingerprint import random
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torchvision.transforms.functional as FT
import torch.nn.functional as F
from torchmetrics.image import StructuralSimilarityIndexMeasure
from nltk.translate.bleu_score import sentence_bleu

from transformers import BertTokenizer
import gc

import textwrap
import sys
import os

# Make the repository root importable so that `from src....` resolves.
# The previous value (r".\dnnls_final_project") was a Windows-style relative
# path; on Colab/Linux the backslash is an ordinary filename character, so the
# entry appended to sys.path pointed at a directory that does not exist.
# Resolve the root portably instead: use the working directory when it already
# is the repository root (the clone cell cd's into it), otherwise fall back to
# a `dnnls_final_project` folder next to the notebook (local runs).
abspath = os.path.abspath(os.getcwd())
if not os.path.isdir(os.path.join(abspath, "src")):
    abspath = os.path.abspath(os.path.join(abspath, "dnnls_final_project"))
if abspath not in sys.path:
    sys.path.append(abspath)

# put src imports here models training utils etc
from src.attention.attention import Attention

from src.dataloaders.sp_dataset import SequencePredictionDataset
from src.dataloaders.tag_extraction import parse_gdi_text
from src.dataloaders.text_dataset import TextTaskDataset
from src.dataloaders.vae_dataset import AutoEncoderTaskDataset

# from src.encoders.visual_autoencoder import Backbone, VisualEncoder, VisualDecoder, VisualAutoencoder
from src.encoders.text_autoencoder import EncoderLSTM, DecoderLSTM, Seq2SeqLSTM
from src.encoders.perceptual_loss import PerceptualLoss
# from src.encoders.unetvisual_autoencoder import UNetBackbone, UNetVisualEncoder, UNetVisualDecoder, UNetVisualAutoencoder
# from src.encoders.VAEvisual_autoencoder import VAEBackbone, VAEVisualEncoder, VAEVisualDecoder, VAEVisualAutoencoder
# from src.encoders.Newvisual_autoencoder import NewBackbone, NewVisualEncoder, NewVisualDecoder, NewVisualAutoencoder
from src.encoders.resnet18_visualautoencoder import ResNet18Backbone, NewVisualEncoder, NewVisualDecoder, NewVisualAutoencoder
from src.encoders.gradient_loss import sobel_gradients, sobel_gradient_loss

# from src.model.sequence_predictor import SequencePredictor
from src.model.cmsequence_predictor import CMSequencePredictor

from src.train.train_sequence_predictor import train_sequence_predictor
from src.train.train_visual_autoencoder import train_visual_autoencoder

from src.utils.checkpoints import save_checkpoint_to_drive, load_checkpoint_from_drive
from src.utils.training_utils import init_weights, validation, show_image
from src.utils.token_generate import generate


In [None]:
# @title Variables and initial setup
# Collect Python garbage first so that tensors kept alive only by reference
# cycles are released, THEN hand the freed cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed seed so weight initialisation and dataloader shuffling are repeatable
# after Restart -> Run All.
SEED = 42
torch.manual_seed(SEED)

N_EPOCHS = 10
emb_dim = 16
latent_dim = 16 # Changed from 256 to 16 to match the checkpoint
num_layers = 1
# Passed positionally to EncoderLSTM / DecoderLSTM in the model-setup cell.
# NOTE(review): confirm those constructors expect a flag here and not a dropout probability.
dropout = True

In [None]:
# @title Dataset loading/creation

def stack_images(batch):
    """Collate a list of equally-shaped image tensors into one tensor along a new leading dim."""
    return torch.stack(batch, dim=0)


# Loading the dataset
train_dataset = load_dataset("daniel3303/StoryReasoning", split="train")
test_dataset = load_dataset("daniel3303/StoryReasoning", split="test")

# Split the training dataset into training and validation sets.
# The split is done ONCE on the index range and reused by every task-specific
# dataset below, so a story never ends up in "train" for one task and "val"
# for another. The generator is seeded so the split is also identical across
# sessions: checkpoints are reloaded between runs, and an unseeded split would
# put samples the model has already trained on into the new validation set.
# NOTE(review): checkpoints trained before this seed was introduced used a
# different (random) split.
SPLIT_SEED = 42
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_indices, val_indices = random_split(
    range(len(train_dataset)),
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# One tokenizer shared by the sequence-prediction and text tasks
# (it was previously loaded twice with identical arguments).
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased",  padding=True, truncation=True)

# --- Sequence prediction task ---
sp_train_dataset = SequencePredictionDataset(train_dataset, tokenizer) # Instantiate the train dataset
sp_test_dataset = SequencePredictionDataset(test_dataset, tokenizer) # Instantiate the test dataset

sp_train_subset = Subset(sp_train_dataset, train_indices)
train_dataloader = DataLoader(sp_train_subset, batch_size=8, shuffle=True)
# We will use the validation set to visualize the progress
# (kept shuffled so different samples are seen each time).
sp_val_subset = Subset(sp_train_dataset, val_indices)
val_dataloader = DataLoader(sp_val_subset, batch_size=4, shuffle=True)

# Test data is iterated in a fixed order: the evaluation cell computes a
# per-batch mean image, so the reported loss depends on batch composition and
# would otherwise change from run to run.
test_dataloader = DataLoader(sp_test_dataset, batch_size=4, shuffle=False)

# --- Text task ---
text_dataset = TextTaskDataset(train_dataset)
text_test_dataset = TextTaskDataset(test_dataset)

# Universal split applied HERE (use the same indices)
text_train_subset = Subset(text_dataset, train_indices)
text_val_subset   = Subset(text_dataset, val_indices)

text_train_dataloader = DataLoader(text_train_subset, batch_size=4, shuffle=True)
text_val_dataloader   = DataLoader(text_val_subset, batch_size=4, shuffle=False)
text_test_dataloader  = DataLoader(text_test_dataset, batch_size=4, shuffle=False)

# --- Image (autoencoder) task ---
autoencoder_train_dataset = AutoEncoderTaskDataset(train_dataset)
autoencoder_test_dataset = AutoEncoderTaskDataset(test_dataset)
ae_train_subset = Subset(autoencoder_train_dataset, train_indices)
ae_val_subset   = Subset(autoencoder_train_dataset, val_indices)

autoencoder_train_dataloader = DataLoader(ae_train_subset, batch_size=4, shuffle=True,
                                    collate_fn=stack_images)
autoencoder_test_dataloader = DataLoader(autoencoder_test_dataset, batch_size=4, shuffle=False,
                                    collate_fn=stack_images)
autoencoder_val_dataloader = DataLoader(ae_val_subset, batch_size=4, shuffle=False,
                                    collate_fn=stack_images)

In [None]:
# @title Initialising models and setup
from src.encoders.perceptual_loss import PerceptualLoss  # repeated from the imports cell so this cell can be re-run on its own

# ---------------------------------------------------------------------------
# Text autoencoder
# ---------------------------------------------------------------------------
encoder = EncoderLSTM(tokenizer.vocab_size, emb_dim, latent_dim, num_layers, dropout).to(device)
decoder = DecoderLSTM(tokenizer.vocab_size, emb_dim, latent_dim, num_layers, dropout).to(device)
text_autoencoder = Seq2SeqLSTM(encoder, decoder).to(device)
# No optimizer is restored (None): only the weights are needed here.
text_autoencoder, _, epoch_loaded, loss_loaded = load_checkpoint_from_drive(text_autoencoder, None, filename='text_autoencoder.pth') # remember to load correct one

total_params = sum(p.numel() for p in text_autoencoder.parameters())
print(f"Total parameters (Not trainable): {total_params}")
print(f"Text Autoencoder loaded from epoch {epoch_loaded} with loss {loss_loaded:.4f}")
# Optional: freeze the text autoencoder for efficiency (currently left trainable).
#for param in text_autoencoder.parameters():
    #param.requires_grad = False


# ---------------------------------------------------------------------------
# Visual autoencoder
# ---------------------------------------------------------------------------
# `NewVisualAutoencoder` (ResNet-18 variant) is the only visual autoencoder
# class imported above; the previously referenced `VisualAutoencoder` is
# commented out in the imports cell and raised NameError here.
# NOTE(review): assumes it accepts `latent_dim=` like the older variants — confirm.
visual_autoencoder = NewVisualAutoencoder(latent_dim=latent_dim).to(device)
visual_autoencoder.apply(init_weights)

# Optimizer for training the visual autoencoder itself. It is built BEFORE the
# freeze below so that it captures the parameters that are trainable at this
# point. It is only usable if the freeze is skipped (see the loop below).
visual_ae_optimizer = torch.optim.Adam(
    (p for p in visual_autoencoder.parameters() if p.requires_grad),
    lr=1e-2
)

visual_autoencoder, _, epoch_loaded, loss_loaded = load_checkpoint_from_drive(
    visual_autoencoder, None,
    filename="Newvisual_autoencoder128244244resnet.pth"
)

# Freeze for sequence prediction tasks. Comment this loop out when the
# autoencoder itself is to be trained in the image-autoencoder training cell.
for p in visual_autoencoder.parameters():
    p.requires_grad = False

total_params = sum(p.numel() for p in visual_autoencoder.parameters() if p.requires_grad)
print(f"Total trainable parameters in visual autoencoder: {total_params}")
#print(f"Visual Autoencoder loaded from epoch {epoch_loaded} with loss {loss_loaded:.4f}")


# ---------------------------------------------------------------------------
# Cross-modal sequence predictor (wraps both autoencoders)
# ---------------------------------------------------------------------------
sequence_predictor = CMSequencePredictor(
    visual_autoencoder,
    text_autoencoder,
    latent_dim, latent_dim).to(device)

# The closing parenthesis of this call used to be commented out, which made
# the whole cell a SyntaxError.
sequence_predictor, _, epoch_loaded, loss_loaded = load_checkpoint_from_drive(
    sequence_predictor, None,
    filename="sequence_predictorFINAL.pth"
)
total_params = sum(p.numel() for p in sequence_predictor.parameters())
print(f"Total parameters: {total_params}")
#print(f"Sequence Prediction loaded from epoch {epoch_loaded} with loss {loss_loaded:.4f}")


# ---------------------------------------------------------------------------
# Training tools
# ---------------------------------------------------------------------------
criterion_images = nn.L1Loss()
criterion_ctx = nn.MSELoss()
# Padding positions are excluded from the text loss.
criterion_text = nn.CrossEntropyLoss(ignore_index=tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
criterion_percep = PerceptualLoss(device).to(device)

# Optimizer for SEQUENCE PREDICTOR created AFTER model definition, so it only
# receives parameters that are still trainable (frozen ones are filtered out).
optimizer = torch.optim.Adam(
    (p for p in sequence_predictor.parameters() if p.requires_grad),
    lr=0.001
)



In [None]:
# @title Sequence predictor training sequence
losses = []

# Per-epoch histories; the plotting cell reads these lists by name.
all_epoch_losses = []
all_train_mse, all_train_perplexity, all_train_bleu = [], [], []
all_train_crossmodal, all_train_ssim = [], []
all_val_mse, all_val_perplexity, all_val_bleu = [], [], []
all_val_crossmodal, all_val_ssim = [], []

# Maps each key of the dict returned by train_sequence_predictor to the
# history list it is appended to (insertion order = append order).
history = {
    "epoch_losses": all_epoch_losses,
    "train_mse": all_train_mse,
    "train_perplexity": all_train_perplexity,
    "train_bleu": all_train_bleu,
    "train_crossmodal": all_train_crossmodal,
    "train_ssim": all_train_ssim,
    "val_mse": all_val_mse,
    "val_perplexity": all_val_perplexity,
    "val_bleu": all_val_bleu,
    "val_crossmodal": all_val_crossmodal,
    "val_ssim": all_val_ssim,
}

lambda_cm = 0.01
N_EPOCHS = 21

for epoch in range(N_EPOCHS):

    # The trainer is driven one epoch at a time so that metrics can be
    # collected and checkpoints written from this cell.
    metrics = train_sequence_predictor(
        model=sequence_predictor,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        optimizer=optimizer,
        criterion_images=criterion_images,
        criterion_ctx=criterion_ctx,
        criterion_text=criterion_text,
        tokenizer=tokenizer,
        device=device,
        num_epochs=1,
        lambda_cm=lambda_cm
    )

    # With num_epochs=1 only the first entry of each returned series is read.
    for key, series in history.items():
        series.append(metrics[key][0])

    # Checkpoint on epochs 0, 5, 10, ...
    if epoch % 5:
        continue
    save_checkpoint_to_drive(
        sequence_predictor,
        optimizer,
        epoch,
        all_epoch_losses[-1],
        filename="sequence_predictor.pth"
    )


In [None]:
# Evaluation of the sequence predictor on the held-out test set.
# The model is not updated in this cell, so one pass over the test data is
# sufficient; the previous version repeated the identical evaluation 21 times
# under an "Epoch" label.
lambda_cm = 0.01        # weight of the cross-modal alignment term (same value as in training)
context_weight = 0.2    # weight of the image-context term

sequence_predictor.eval()
running_loss = 0.0
with torch.no_grad():
    for frames, descriptions, image_target, text_target in test_dataloader:
        # Send images and tokens to the GPU
        descriptions = descriptions.to(device)
        frames = frames.to(device)
        image_target = image_target.to(device)
        text_target = text_target.to(device)

        # Predictions from our model
        (pred_image_content, pred_image_context, predicted_text_logits_k,
         _, _, z_t_flat, z_v_flat) = sequence_predictor(frames, descriptions, text_target)

        # Loss for image reconstruction
        loss_im = criterion_images(pred_image_content, image_target)

        # Loss for the average pattern the images contain: the target is the
        # mean over the first two dims of `frames`, broadcast to the
        # prediction's shape.
        mu_global = frames.mean(dim=[0, 1])
        mu_global = mu_global.unsqueeze(0).expand_as(pred_image_context)
        loss_context = criterion_ctx(pred_image_context, mu_global)

        # Loss function for the text prediction: the first target position is
        # dropped before flattening against the logits.
        prediction_flat = predicted_text_logits_k.reshape(-1, tokenizer.vocab_size)
        target_labels = text_target.squeeze(1)[:, 1:]
        target_flat = target_labels.reshape(-1)
        loss_text = criterion_text(prediction_flat, target_flat)

        # Cross-modal alignment: 1 - mean cosine similarity of the two latents
        loss_align = 1 - F.cosine_similarity(z_v_flat, z_t_flat, dim=1).mean()

        # Combining the losses
        loss = loss_im + loss_text + context_weight * loss_context + lambda_cm * loss_align

        # Weight by batch size so the final figure is a per-sample average
        running_loss += loss.item() * frames.size(0)

    epoch_loss = running_loss / len(test_dataloader.dataset)
    print(f"Test Combined Loss: {epoch_loss:.4f}")

    test_metrics = validation(
        model=sequence_predictor,
        data_loader=test_dataloader,
        device=device,
        tokenizer=tokenizer,
        criterion_text=criterion_text,
        criterion_ctx=criterion_ctx
    )


In [None]:
# @title Computing and showing average images
# Per-position mean / standard-deviation images over the first N training
# sequences, followed by two figures: the mean images, and a matrix of their
# pairwise differences.

N = min(1000, len(train_dataset))  # never index past the end of a smaller dataset
H, W = 60, 125
N_FRAMES = 5                       # frames read from each sequence


def _minmax(img, eps=1e-8):
    """Rescale a tensor to [0, 1] for display; a constant tensor maps to zeros instead of NaN."""
    lo, hi = torch.min(img), torch.max(img)
    return (img - lo) / (hi - lo).clamp_min(eps)


# Tensors to accumulate sum (for mean) and sum of squares (for variance)
avg_images = [torch.zeros((3, H, W)) for _ in range(N_FRAMES)]
sum_sq_diff = [torch.zeros((3, H, W)) for _ in range(N_FRAMES)]
transform = transforms.Compose([
  transforms.Resize((H, W)),  # bring every frame to a common size
  transforms.ToTensor()
])

# --- First Pass: Calculate the Sum (for Mean) ---
print("Starting Pass 1: Calculating Mean...")

for i in range(N):
    sequence = train_dataset[i]["images"]
    for j in range(N_FRAMES):
        avg_images[j] += transform(sequence[j])

for j in range(N_FRAMES):
    avg_images[j] /= N

# --- Second Pass: accumulate squared deviations from the mean ---
print("Starting Pass 2: Calculating Variance...")

for i in range(N):
    sequence = train_dataset[i]["images"]
    for j in range(N_FRAMES):
        diff = transform(sequence[j]) - avg_images[j]
        sum_sq_diff[j] += diff * diff

# --- Final step for Standard Deviation ---
# Population standard deviation: sqrt(sum of squared differences / N)
std_images = [torch.sqrt(sq / N) for sq in sum_sq_diff]

print("Computation Complete. std_images is a list of 5 tensors (3x60x125).")

# Figure 1: the mean image of each frame position
fig, ax = plt.subplots(1, N_FRAMES, figsize=(15,5))
for i in range(N_FRAMES):
  avg_image = avg_images[i]

  # Printing range of avg_image
  print(torch.min(avg_image), torch.max(avg_image))

  show_image(ax[i], _minmax(avg_image))

# Figure 2: matrix of images — means on the diagonal, pairwise differences elsewhere
fig, ax = plt.subplots(N_FRAMES, N_FRAMES, figsize=(15,8))

for i in range(N_FRAMES):
  for j in range(N_FRAMES):
    panel = avg_images[i] if i == j else avg_images[i] - avg_images[j]
    show_image(ax[i,j], _minmax(panel))
    ax[i,j].set_xticks([])
    ax[i,j].set_yticks([])
plt.tight_layout()
plt.subplots_adjust(
    wspace=0, # Set horizontal space to zero
    hspace=0  # Set vertical space to zero
)


In [None]:
# @title Image autoencoder training
# NOTE(review): the model-setup cell sets requires_grad = False on every
# visual_autoencoder parameter (its own comment says to toggle this by hand).
# Presumably that freeze has to be skipped before running this cell — confirm.
lambda_percep = 0.03
lambda_ctx = 0.2
lambda_grad = 0.005
N_EPOCHS = 26
CHECKPOINT_EVERY = 5   # epochs between checkpoints
VIS_EVERY = 1          # epochs between reconstruction previews

# ImageNet-style channel statistics handed to show_image (de-normalisation is off).
img_mean = torch.tensor([0.485, 0.456, 0.406])
img_std  = torch.tensor([0.229, 0.224, 0.225])

# One batch kept constant across epochs so the printed statistics are comparable.
fixed_batch = next(iter(autoencoder_train_dataloader)).to(device)

for epoch in range(N_EPOCHS):

    epoch_losses = train_visual_autoencoder(
        model=visual_autoencoder,
        train_dataloader=autoencoder_train_dataloader,
        optimizer=visual_ae_optimizer,
        criterion_images=criterion_images,
        criterion_percep=criterion_percep,
        criterion_ctx=criterion_ctx,
        lambda_percep=lambda_percep,
        lambda_ctx=lambda_ctx,
        lambda_grad=lambda_grad,
        device=device
    )

    if epoch % CHECKPOINT_EVERY == 0:
        save_checkpoint_to_drive(
            visual_autoencoder,
            visual_ae_optimizer,
            epoch,
            epoch_losses[-1],
            filename="Newvisual_autoencoder128244244resnet.pth"
        )

    if epoch % VIS_EVERY == 0:
        # Remember each sub-module's train/eval flag so the preview below does
        # not leave the model in eval mode for the next training epoch.
        previous_modes = [(m, m.training) for m in visual_autoencoder.modules()]
        visual_autoencoder.eval()

        with torch.no_grad():
            # A fresh (shuffled) batch for the pictures
            sample_batch = next(iter(autoencoder_train_dataloader)).to(device)
            x_content, x_context = visual_autoencoder(sample_batch)

            # The fixed batch for the running statistics
            z = visual_autoencoder.encoder(fixed_batch)
            fixedx_content, fixedx_context = visual_autoencoder(fixed_batch)

            print("img mean/std:",
                  fixed_batch.mean().item(), fixed_batch.std().item(),
                  "min/max:",
                  fixed_batch.min().item(), fixed_batch.max().item())

            print("out mean/std:",
                  fixedx_content.mean().item(), fixedx_content.std().item(),
                  "min/max:",
                  fixedx_content.min().item(), fixedx_content.max().item())

            print("z mean/std:",
                  z.mean().item(), z.std().item())

            fig, axs = plt.subplots(1, 3, figsize=(12, 4))

            show_image(axs[0], sample_batch[0].cpu(), de_normalize=False, img_mean=img_mean, img_std=img_std)
            axs[0].set_title("Original")

            show_image(axs[1], x_content[0].cpu(), de_normalize=False, img_mean=img_mean, img_std=img_std)
            axs[1].set_title("Content")

            # Context output for the SAME sample. The previous code indexed
            # x_hat[0][1], i.e. the content reconstruction of the second
            # sample in the batch.
            # NOTE(review): assumes x_context is image-shaped like x_content — confirm.
            show_image(axs[2], x_context[0].cpu(), de_normalize=False, img_mean=img_mean, img_std=img_std)
            axs[2].set_title("Context")

            for a in axs:
                a.axis("off")

            plt.show()
            plt.close(fig)  # one figure per epoch would otherwise accumulate in memory

        # Restore the modes captured before the preview
        for module, was_training in previous_modes:
            module.training = was_training



In [None]:
# @title SP METRIC PLOTTING
# Plots the per-epoch histories collected by the sequence-predictor training
# cell. The earlier version read `all_test_*` lists that are never created
# anywhere in the notebook (NameError) and titled the training curves
# "Testing"; the training histories `all_train_*` are used instead.

def plot_metric(values, ylabel, title):
    """Draw one metric history as a line plot with 'Epoch' on the x axis."""
    fig, ax = plt.subplots()
    ax.plot(values)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    plt.show()
    plt.close(fig)


# (history, y-axis label, title) — plotted in this order
metric_curves = [
    # Training
    (all_epoch_losses,     "Loss",                   "Training Loss Over Time"),
    (all_train_ssim,       "SSIM",                   "Training SSIM Over Epochs"),
    (all_train_bleu,       "BLEU",                   "Training BLEU Over Epochs"),
    (all_train_mse,        "MSE",                    "Training MSE Over Epochs"),
    (all_train_perplexity, "Perplexity",             "Training Perplexity Over Epochs"),
    (all_train_crossmodal, "Cross-modal similarity", "Training Cross-modal Similarity Over Epochs"),
    # Validation
    (all_val_ssim,         "SSIM",                   "Validation SSIM Over Epochs"),
    (all_val_bleu,         "BLEU",                   "Validation BLEU Over Epochs"),
    (all_val_mse,          "MSE",                    "Validation MSE Over Epochs"),
    (all_val_perplexity,   "Perplexity",             "Validation Perplexity Over Epochs"),
    (all_val_crossmodal,   "Cross-modal similarity", "Validation Cross-modal Similarity Over Epochs"),
]

for values, ylabel, title in metric_curves:
    plot_metric(values, ylabel, title)
