In [3]:
!pip install -r requirements.txt

Collecting scikit-learn (from -r requirements.txt (line 6))
  Downloading scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting wandb (from -r requirements.txt (line 8))
  Downloading wandb-0.22.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting scipy>=1.8.0 (from scikit-learn->-r requirements.txt (line 6))
  Downloading scipy-1.16.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (62 kB)
Collecting joblib>=1.2.0 (from scikit-learn->-r requirements.txt (line 6))
  Downloading joblib-1.5.2-py3-none-any.whl.metadata (5.6 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn->-r requirements.txt (line 6))
  Downloading threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)
Collecting click>=8.0.1 (from wandb->-r requirements.txt (line 8))
  Downloading click-8.3.0-py3-none-any.whl.metadata (2.6 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb->-r requirements.txt (line 8))
  Down

In [1]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class EEGMultimodalDataset(Dataset):
    """
    PyTorch Dataset class for loading the multimodal EEG, Image, and Text data.
    
    This version contains the fix for returning the 'category' for classification.
    """
    def __init__(self, 
                 bids_root,          # Path to the .../ds005589/ directory
                 images_dir,         # Path to the .../All_images/ directory
                 captions_path,      # Path to the captions.txt file
                 subject_list,       # List of subjects to load, e.g., ['sub-02', 'sub-03']
                 session_list,       # List of sessions to load, e.g., ['ses-01', 'ses-02']
                 image_transform=None, # PyTorch transforms for the images
                 clamp_thres=500     # Clamping threshold for EEG in microvolts
                ):
        
        self.bids_root = bids_root
        self.images_dir = images_dir
        self.image_transform = image_transform
        self.clamp_thres = clamp_thres

        # 1. Initialize all 4 lists
        self.all_eeg_trials = []
        self.all_image_paths = []
        self.all_captions = []
        self.all_categories = [] # <-- For categories
        
        print("Initializing dataset... This may take a moment.")
        
        print(f"Loading captions from {captions_path}...")
        self.captions_dict = self._load_captions(captions_path)
        print(f"Loaded {len(self.captions_dict)} captions.")

        for sub in subject_list:
            for ses in session_list:
                for run in ['01', '02', '03', '04']:
                    
                    session_path = os.path.join(self.bids_root, sub, ses)
                    csv_path = os.path.join(session_path, f"{sub}_{ses}_task-lowSpeed_run-{run}_image.csv")
                    npy_path = os.path.join(session_path, f"{sub}_{ses}_task-lowSpeed_run-{run}_1000Hz.npy")
                    
                    if not (os.path.exists(csv_path) and os.path.exists(npy_path)):
                        # print(f"Warning: Missing files for {sub} {ses} {run}. Skipping.")
                        continue
                        
                    try:
                        csv_data = pd.read_csv(csv_path) 
                    except Exception as e:
                        print(f"Error reading CSV {csv_path}: {e}. Skipping run.")
                        continue
                    
                    eeg_data = np.load(npy_path) 
                    
                    if eeg_data.shape[0] != len(csv_data):
                        print(f"Warning: Trial mismatch in {sub} {ses} {run}. "
                              f"EEG has {eeg_data.shape[0]}, CSV has {len(csv_data)}. Skipping.")
                        continue
                        
                    for i, row in csv_data.iterrows():
                        
                        img_base_name = self._get_base_name(row['FilePath']) 
                        if not img_base_name:
                            continue
                        
                        category, caption = self.captions_dict.get(img_base_name, ("Unknown", "No Caption"))
                        
                        img_path = self._find_image_path(img_base_name)
                        if not img_path:
                            continue 
                            
                        # 2. Append all 4 items in the loop
                        self.all_eeg_trials.append(eeg_data[i])   
                        self.all_image_paths.append(img_path)     
                        self.all_captions.append(caption)         
                        self.all_categories.append(category) # <-- *** THIS WAS THE FIX ***

        print(f"Found {len(self.all_eeg_trials)} total aligned trials.")
        
        if len(self.all_eeg_trials) == 0:
            print("ERROR: No trials were loaded. Check your BIDS_ROOT, IMAGE_DIR, and CAPTIONS_FILE paths.")
            self.eeg_dataset = np.array([])
            self.image_paths = []
            self.captions = []
            self.categories = [] # 3. Store the (empty) list
            return

        eeg_dataset = np.array(self.all_eeg_trials, dtype=np.float32)
        
        # Clamp
        eeg_dataset[eeg_dataset >  self.clamp_thres] =  self.clamp_thres
        eeg_dataset[eeg_dataset < -self.clamp_thres] = -self.clamp_thres
        
        # Normalize
        sample_num, time_num, channel_num = eeg_dataset.shape # <-- Corrected shape
        eeg_dataset_flat = eeg_dataset.reshape(sample_num, -1)
        
        mean = np.mean(eeg_dataset_flat, axis=0)
        std = np.std(eeg_dataset_flat, axis=0)
        
        eeg_dataset_flat = (eeg_dataset_flat - mean) / (std + 1e-6)
        
        self.eeg_dataset = eeg_dataset_flat.reshape(sample_num, time_num, channel_num)
        
        # 3. Store all 4 lists
        self.image_paths = self.all_image_paths
        self.captions = self.all_captions
        self.categories = self.all_categories
        
        print("Dataset initialization complete.")

    def _load_captions(self, captions_path):
            captions_dict = {}
            with open(captions_path, 'r') as f:
                next(f) # Skip header
                for line in f:
                    parts = line.strip().split('\t') 
                    if len(parts) == 4:
                        source, category, img_name, caption = parts
                        captions_dict[img_name] = (category, caption)
            return captions_dict

    def _get_base_name(self, file_path):
            try:
                normalized_path = str(file_path).replace('\\', '/') 
                base_name_with_ext = os.path.basename(normalized_path) 
                base_name_resized = os.path.splitext(base_name_with_ext)[0]
                
                if base_name_resized.endswith('_resized'):
                    base_name = base_name_resized[:-len('_resized')]
                else:
                    base_name = base_name_resized
                return base_name 
            except Exception as e:
                print(f"ERROR in _get_base_name: {e}")
                return None

    def _find_image_path(self, img_base_name):
        for ext in ['.jpg', '.jpeg', '.png', '.JPEG']: 
            img_path = os.path.join(self.images_dir, img_base_name + ext)
            if os.path.exists(img_path):
                return img_path
        return None 

    def __len__(self):
        return len(self.eeg_dataset)

    def __getitem__(self, idx):
        # Your dataloader gives [B, T, C] -> [32, 500, 122]
        # We need to swap this for Conv1D models
        eeg_data = self.eeg_dataset[idx] # Shape [500, 122]
        eeg_tensor = torch.tensor(eeg_data).float()
        
        caption = self.captions[idx] 
        img_path = self.image_paths[idx]
        
        try:
            image = Image.open(img_path).convert('RGB')
            if self.image_transform:
                image_tensor = self.image_transform(image)
            else:
                image_tensor = transforms.ToTensor()(image)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}. Returning a dummy image.")
            image_tensor = torch.zeros(3, 224, 224) 

        # 4. Return all 4 items
        category = self.categories[idx]
        return eeg_tensor, image_tensor, caption, category

In [2]:
import os
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# --- 1. Define Your Paths ---
# (Update these paths to match your system)
BIDS_ROOT = '/ocean/projects/cis250019p/gandotra/11785-gp-eeg/ds005589'
IMAGE_DIR = '/ocean/projects/cis250019p/gandotra/11785-gp-eeg/images'
CAPTIONS_FILE = '/ocean/projects/cis250019p/gandotra/11785-gp-eeg/captions.txt'

# --- 2. Define Your Subject List ---
ALL_SUBJECTS = ['sub-02', 'sub-03', 'sub-05', 'sub-09', 'sub-14', 'sub-15', 
                'sub-17', 'sub-19', 'sub-20', 'sub-23', 'sub-24', 'sub-28', 'sub-29']

# --- 3. Define Image Transforms (e.g., for CLIP) ---
# (You would get the specific transforms from your model)
image_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# --- 4. Create the 3 Datasets (Train/Val/Test) ---
# This perfectly follows the paper's "split by session" rule.

print("Creating Training Dataset...")
train_dataset = EEGMultimodalDataset(
    bids_root=BIDS_ROOT,
    images_dir=IMAGE_DIR,
    captions_path=CAPTIONS_FILE,
    subject_list=ALL_SUBJECTS,
    session_list=['ses-01', 'ses-02', 'ses-03'], # 3 sessions for training
    image_transform=image_transforms
)

print("\nCreating Validation Dataset...")
val_dataset = EEGMultimodalDataset(
    bids_root=BIDS_ROOT,
    images_dir=IMAGE_DIR,
    captions_path=CAPTIONS_FILE,
    subject_list=ALL_SUBJECTS,
    session_list=['ses-04'], # 1 session for validation
    image_transform=image_transforms
)

print("\nCreating Test Dataset...")
test_dataset = EEGMultimodalDataset(
    bids_root=BIDS_ROOT,
    images_dir=IMAGE_DIR,
    captions_path=CAPTIONS_FILE,
    subject_list=ALL_SUBJECTS,
    session_list=['ses-05'], # 1 session for testing
    image_transform=image_transforms
)

# --- 5. Create PyTorch DataLoaders ---
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

# --- 6. Test the loader ---
print("\nTesting the training loader...")
eeg_batch, image_batch, caption_batch, category_batch = next(iter(train_loader)) 

print(f"EEG batch shape:   {eeg_batch.shape}")
print(f"Image batch shape: {image_batch.shape}")
print(f"Caption batch (first item): '{caption_batch[0]}'")
print(f"Category batch (first item): '{category_batch[0]}'")

Creating Training Dataset...
Initializing dataset... This may take a moment.
Loading captions from /ocean/projects/cis250019p/gandotra/11785-gp-eeg/captions.txt...
Loaded 9825 captions.
Found 15600 total aligned trials.
Dataset initialization complete.

Creating Validation Dataset...
Initializing dataset... This may take a moment.
Loading captions from /ocean/projects/cis250019p/gandotra/11785-gp-eeg/captions.txt...
Loaded 9825 captions.
Found 5200 total aligned trials.
Dataset initialization complete.

Creating Test Dataset...
Initializing dataset... This may take a moment.
Loading captions from /ocean/projects/cis250019p/gandotra/11785-gp-eeg/captions.txt...
Loaded 9825 captions.
Found 5200 total aligned trials.
Dataset initialization complete.

Testing the training loader...
EEG batch shape:   torch.Size([32, 500, 122])
Image batch shape: torch.Size([32, 3, 224, 224])
Caption batch (first item): 'Diningtable with checkered cloth and dishes'
Category batch (first item): 'diningtable'

In [3]:
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
import wandb  # <-- Import W&B

# --- Login to W&B Automatically ---
# Hardcoding your API key for automatic login
wandb.login(key="be570aff6d4f4fd5239571214e49fb3e718f29c8")

# --- Define Device ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# --- Build Class Label Mapping ---
print("Building class-to-index mapping...")
all_cats = sorted(list(set(train_dataset.categories)))
num_classes = len(all_cats)
label_to_index = {label: i for i, label in enumerate(all_cats)}
index_to_label = {i: label for label, i in label_to_index.items()}
print(f"Found {num_classes} unique classes.")


# --- Model Definitions (from repo, fixed) ---

class ModelFC(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim):
        super(ModelFC, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.BatchNorm1d(num_features=hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, out_dim),
            # nn.Softmax(dim=-1) # <-- DELETED (This is the critical bug fix)
        )  
    def forward(self, x):
        # x shape is [bs, time, electrode] -> [32, 500, 122]
        # Flatten for FC layer
        x_flat = x.reshape(x.shape[0], -1)
        return self.model(x_flat)

class ModelConv(nn.Module):
    def __init__(self, electrode_num=122, class_num=20,
                 ch1=128, ch2=256, ch3=512, 
                 kernal1=3, kernal2=3, kernal3=3):  
        super(ModelConv, self).__init__()     
        
        self.model_conv = nn.Sequential(
            nn.Conv1d(in_channels=electrode_num, out_channels=ch1, kernel_size=kernal1),
            nn.BatchNorm1d(num_features=ch1), nn.ReLU(),
            nn.AvgPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=ch1, out_channels=ch2, kernel_size=kernal2),
            nn.BatchNorm1d(num_features=ch2), nn.ReLU(),
            nn.AvgPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=ch2, out_channels=ch3, kernel_size=kernal3),
            nn.BatchNorm1d(num_features=ch3), nn.ReLU(),
            nn.AvgPool1d(kernel_size=2, stride=2),
        )
        
        # --- FIX: Calculate fc_in_dim automatically ---
        dummy_input = torch.randn(1, 500, electrode_num) # (1, 500, 122)
        dummy_transposed = dummy_input.transpose(dim0=1, dim1=2) # (1, 122, 500)
        conv_out = self.model_conv(dummy_transposed)
        fc_in_dim = conv_out.reshape(1, -1).shape[1]
        print(f"ModelConv: Calculated fc_in_dim = {fc_in_dim}")
        # --- End Fix ---

        self.model_fc = nn.Sequential(
            nn.Linear(fc_in_dim, class_num),
            # nn.Softmax(dim=-1) # <-- DELETED
        )
        
    def forward(self, x):
        x_transpose = x.transpose(dim0=1, dim1=2) # -> [32, 122, 500]
        conv_out = self.model_conv(x_transpose)
        bs = conv_out.shape[0]
        fc_in = conv_out.reshape([bs,-1])
        return self.model_fc(fc_in)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /jet/home/pbhuyan/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mpbhuyan[0m ([33mpbhuyan-carnegie-mellon-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using device: cuda
Building class-to-index mapping...
Found 20 unique classes.


In [9]:
# --- Task 1: Train ModelFC Baseline with W&B ---

# --- 1. Define Run Configuration ---
# All hyperparameters go here
config = {
    "model_name": "ModelFC_baseline",
    "epochs": 1,
    "lr": 5e-5,
    "batch_size": 32,
    "hidden_dim": 256,
    "in_dim": 500 * 122,
    "num_classes": num_classes,
}

# Define a unique ID for this run, so you can resume it
RUN_ID = "modelfc_baseline_run_1"
CKPT_PATH = f"./{RUN_ID}.pth"

# --- 2. Initialize W&B ---
run = wandb.init(
    project="eeg-classification", # Name of your project
    job_type="train",
    config=config,
    id=RUN_ID,        # Set a fixed ID for this run
    resume="allow",   # Allow resuming if this ID exists
)

# --- 3. Instantiate Model, Optimizer, Loss ---
model_fc = ModelFC(config["in_dim"], config["hidden_dim"], config["num_classes"]).to(DEVICE)
optimizer = optim.Adam(model_fc.parameters(), lr=config["lr"])
criterion = nn.CrossEntropyLoss()

# --- 4. Load Checkpoint if Resuming ---
start_epoch = 0
if wandb.run.resumed:
    print(f"Resuming run {RUN_ID}...")
    try:
        print(f"Attempting to load checkpoint from: {CKPT_PATH}")
        checkpoint = torch.load(CKPT_PATH)
        model_fc.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resumed successfully. Starting from epoch {start_epoch}")
    except FileNotFoundError:
        print("No checkpoint file found. Starting from scratch.")
    # --- THIS BLOCK WAS MOVED INSIDE THE 'if' STATEMENT ---
    except Exception as e:
        print(f"Error loading checkpoint: {e}. Starting from scratch.")

# Tell W&B to watch the model
wandb.watch(model_fc, criterion, log="all", log_freq=100)

# --- 5. Training Loop ---
print(f"--- Starting Training for {config['model_name']} from Epoch {start_epoch+1} ---")

# Outer progress bar for epochs
epoch_bar = tqdm(range(start_epoch, config["epochs"]), desc="Epochs")

for epoch in epoch_bar:
    
    # --- Training Phase ---
    model_fc.train()
    train_loss = 0.0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1:03} Training", leave=False)
    
    for eeg_batch, _, _, category_batch in train_bar: 
        eeg_batch = eeg_batch.to(DEVICE)
        labels = torch.tensor([label_to_index[cat] for cat in category_batch], dtype=torch.long).to(DEVICE)
        
        optimizer.zero_grad()
        logits = model_fc(eeg_batch)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        train_bar.set_postfix(loss=f"{loss.item():.4f}")
    
    avg_train_loss = train_loss / len(train_loader)

    # --- Validation Phase ---
    model_fc.eval()
    val_loss = 0.0
    all_preds = []
    all_labels = []
    val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1:03} Validation", leave=False)
    
    with torch.no_grad():
        for eeg_batch, _, _, category_batch in val_bar:
            eeg_batch = eeg_batch.to(DEVICE)
            labels = torch.tensor([label_to_index[cat] for cat in category_batch], dtype=torch.long).to(DEVICE)
            
            logits = model_fc(eeg_batch)
            loss = criterion(logits, labels)
            val_loss += loss.item()
            
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = accuracy_score(all_labels, all_preds)
    
    # --- 6. Log Metrics to W&B ---
    wandb.log({
        "epoch": epoch,
        "train_loss": avg_train_loss,
        "val_loss": avg_val_loss,
        "val_accuracy": val_accuracy
    })
    
    # Update the main epoch bar
    epoch_bar.set_postfix(
        Train_Loss=f"{avg_train_loss:.4f}", 
        Val_Loss=f"{avg_val_loss:.4f}", 
        Val_Acc=f"{val_accuracy*100:.2f}%"
    )

    # --- 7. Save Checkpoint ---
    torch.save({
        'epoch': epoch,
        'model_state_dict': model_fc.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_val_loss,
    }, CKPT_PATH)

# --- 8. Finish the W&B Run ---
wandb.finish()
print(f"--- {config['model_name']} training complete. ---")

Resuming run modelfc_baseline_run_1...
Attempting to load checkpoint from: ./modelfc_baseline_run_1.pth
Resumed successfully. Starting from epoch 100
--- Starting Training for ModelFC_baseline from Epoch 101 ---


Epochs: 0it [00:00, ?it/s]

0,1
epoch,99.0
train_loss,0.40934
val_accuracy,0.07
val_loss,4.75222


--- ModelFC_baseline training complete. ---


In [8]:
!pip install ipywidgets

Collecting ipywidgets
  Downloading ipywidgets-8.1.7-py3-none-any.whl.metadata (2.4 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets)
  Downloading widgetsnbextension-4.0.14-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab_widgets~=3.0.15 (from ipywidgets)
  Downloading jupyterlab_widgets-3.0.15-py3-none-any.whl.metadata (20 kB)
Downloading ipywidgets-8.1.7-py3-none-any.whl (139 kB)
Downloading jupyterlab_widgets-3.0.15-py3-none-any.whl (216 kB)
Downloading widgetsnbextension-4.0.14-py3-none-any.whl (2.2 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m MB/s[0m eta [36m0:00:01[0m
[?25hInstalling collected packages: widgetsnbextension, jupyterlab_widgets, ipywidgets
Successfully installed ipywidgets-8.1.7 jupyterlab_widgets-3.0.15 widgetsnbextension-4.0.14

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;4

In [10]:
# --- Task 2a: Image-Caption Retrieval Baseline ---

from transformers import CLIPProcessor, CLIPModel

# 1. Load frozen CLIP Model and Processor
print("Loading CLIP model...")
clip_model_name = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(clip_model_name).to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)

# Ensure the model is frozen and in evaluation mode
clip_model.eval()
for param in clip_model.parameters():
    param.requires_grad = False
print("CLIP model loaded and frozen.")

# 2. Calculate Embeddings for the Validation Set
# We'll store all image and text embeddings from the val_loader
all_img_embeds = []
all_txt_embeds = []
print("Calculating image and text embeddings for the validation set...")

# Use tqdm for progress
val_embed_bar = tqdm(val_loader, desc="Calculating Embeddings")

with torch.no_grad():
    # Note: We only need image_batch and caption_batch here
    for _, image_batch, caption_batch, _ in val_embed_bar: 
        
        image_batch = image_batch.to(DEVICE)
        
        # Process text captions using the CLIP processor
        # Important: Convert tuple of captions to a list for the processor
        text_inputs = clip_processor(
            text=list(caption_batch), 
            return_tensors="pt", 
            padding=True, 
            truncation=True
        ).to(DEVICE)
        
        # Get embeddings from CLIP
        img_embeds = clip_model.get_image_features(image_batch)
        txt_embeds = clip_model.get_text_features(**text_inputs)
        
        # Normalize embeddings (standard practice for CLIP similarity)
        img_embeds /= img_embeds.norm(dim=-1, keepdim=True)
        txt_embeds /= txt_embeds.norm(dim=-1, keepdim=True)
        
        all_img_embeds.append(img_embeds.cpu()) # Move to CPU to save GPU memory
        all_txt_embeds.append(txt_embeds.cpu())

# Concatenate all batch embeddings into single tensors
all_img_embeds = torch.cat(all_img_embeds).to(DEVICE) # Move back to GPU for similarity calc
all_txt_embeds = torch.cat(all_txt_embeds).to(DEVICE)

print(f"\nFinished calculating embeddings.")
print(f"Image embeds shape: {all_img_embeds.shape}") # Should be [5200, 512]
print(f"Text embeds shape:  {all_txt_embeds.shape}") # Should be [5200, 512]

# 3. Calculate Retrieval Accuracy (Recall@k)
print("\nCalculating retrieval accuracy...")
# Calculate the similarity matrix (cosine similarity)
# Since embeddings are normalized, matmul is equivalent to cosine similarity
similarity_matrix = all_img_embeds @ all_txt_embeds.T

# Helper function to calculate Recall@k
def calculate_recall(sim_matrix, k_values=(1, 5, 10)):
    """Calculates Recall@k for multiple k values."""
    n = len(sim_matrix)
    targets = torch.arange(n).to(sim_matrix.device)
    
    # Get the indices of the top-k most similar text embeddings for each image
    _, topk_indices = sim_matrix.topk(max(k_values), dim=1)
    
    recalls = {}
    for k in k_values:
        # Check if the correct target index is within the top-k predictions
        correct_at_k = topk_indices[:, :k].eq(targets.view(-1, 1)).any(dim=1)
        recalls[k] = (correct_at_k.sum() / n).item()
    return recalls

# Calculate and print recalls
recall_results = calculate_recall(similarity_matrix, k_values=(1, 5, 10))

print(f"\n--- CLIP Baseline Results (Image-to-Text Retrieval on Validation Set) ---")
for k, recall in recall_results.items():
      print(f"Recall@{k:<2}: {recall*100:.2f}%")

Loading CLIP model...
CLIP model loaded and frozen.
Calculating image and text embeddings for the validation set...


Calculating Embeddings:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av


Finished calculating embeddings.
Image embeds shape: torch.Size([5200, 512])
Text embeds shape:  torch.Size([5200, 512])

Calculating retrieval accuracy...

--- CLIP Baseline Results (Image-to-Text Retrieval on Validation Set) ---
Recall@1 : 21.96%
Recall@5 : 49.87%
Recall@10: 61.37%


In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
import wandb
from transformers import CLIPProcessor, CLIPModel # We need CLIP again

# --- Define the EEG-Caption Model ---
class EEGCaptionModel(nn.Module):
    """
    Combines an EEG backbone (like ModelFC or ModelConv)
    with a projection head to map EEG features into CLIP's space (512 dims).
    """
    def __init__(self, eeg_backbone, output_dim=512):
        super().__init__()
        self.backbone = eeg_backbone
        
        # --- Determine backbone output size ---
        # Get the output dimension from the backbone's final layer
        # Assumes the backbone ends with a nn.Linear layer
        if isinstance(eeg_backbone, ModelFC):
            # ModelFC's last layer is nn.Linear(hid_dim, out_dim)
            backbone_out_dim = eeg_backbone.model[-1].in_features 
        elif isinstance(eeg_backbone, ModelConv):
             # ModelConv's last layer is nn.Linear(fc_in_dim, class_num)
            backbone_out_dim = eeg_backbone.model_fc[0].in_features
        # Add checks for ModelLSTM, ModelTransformer if you use them
        else:
            raise TypeError("Unsupported backbone type. Add logic to get output dim.")
        # ------------------------------------

        # Projection head maps backbone output to CLIP's dimension
        self.projection_head = nn.Linear(backbone_out_dim, output_dim)
        
    def forward(self, eeg):
        # Pass EEG through the backbone
        # Need to handle different backbone types if their forward needs adjustment
        if isinstance(self.backbone, ModelFC):
            # ModelFC expects [B, T, C] and flattens it
             # Need to get features *before* the final classification layer
            features = self.backbone.model[:-1](eeg.reshape(eeg.shape[0], -1)) # Get output of ReLU
        elif isinstance(self.backbone, ModelConv):
            # ModelConv expects [B, C, T] - Needs adjustment based on YOUR data!
            # Your dataloader gives [B, T, C] -> [32, 500, 122]
            x_transpose = eeg.transpose(dim0=1, dim1=2) # -> [32, 122, 500]
            conv_out = self.backbone.model_conv(x_transpose)
            bs = conv_out.shape[0]
            features = conv_out.reshape([bs,-1]) # Features before final Linear
            # Pass features through the linear layer of the Conv model's fc part, but NOT softmax
            features = self.backbone.model_fc[0](features) # Output of the Linear layer

        else:
             raise TypeError("Unsupported backbone type for feature extraction.")

        # Pass features through the projection head
        eeg_embeds = self.projection_head(features) # [B, 512]
        return eeg_embeds

# --- Contrastive Loss Function ---
def contrastive_loss(eeg_embeds, txt_embeds, temperature=0.07):
    # Normalize embeddings (important for cosine similarity)
    eeg_embeds = eeg_embeds / eeg_embeds.norm(dim=-1, keepdim=True)
    txt_embeds = txt_embeds / txt_embeds.norm(dim=-1, keepdim=True)
    
    # Calculate cosine similarity matrix (logit scale)
    # Higher temperature -> softer probabilities, lower -> harder
    logit_scale = clip_model.logit_scale.exp() # Use CLIP's learned scale
    sim_matrix = (eeg_embeds @ txt_embeds.T) * logit_scale
    
    # Ground truth: diagonal elements are the correct pairs
    labels = torch.arange(len(sim_matrix)).to(sim_matrix.device)
    
    # Calculate CrossEntropyLoss in both directions
    loss_eeg_to_txt = nn.CrossEntropyLoss()(sim_matrix, labels)
    loss_txt_to_eeg = nn.CrossEntropyLoss()(sim_matrix.T, labels)
    
    # Average the two losses
    return (loss_eeg_to_txt + loss_txt_to_eeg) / 2.0

# --- Helper function for Recall@k (same as before) ---
def calculate_recall(sim_matrix, k_values=(1, 5, 10)):
    n = len(sim_matrix)
    targets = torch.arange(n).to(sim_matrix.device)
    _, topk_indices = sim_matrix.topk(max(k_values), dim=1)
    recalls = {}
    for k in k_values:
        correct_at_k = topk_indices[:, :k].eq(targets.view(-1, 1)).any(dim=1)
        recalls[k] = (correct_at_k.sum() / n).item()
    return recalls


# --- Task 2c: Training Loop ---

# --- 1. Define Run Configuration ---
config_retrieval = {
    "model_name": "EEG_Caption_Retrieval_ModelFC",
    "epochs": 50, # Retrieval might need fewer/more epochs than classification
    "lr": 1e-4, # May need tuning
    "batch_size": 32, # From your dataloader
    "backbone": "ModelFC", # Specify which backbone you're using
    "clip_model_name": "openai/clip-vit-base-patch32",
    "contrastive_temperature": 0.07, # Common default for CLIP
}

RUN_ID_RETRIEVAL = "eeg_retrieval_modelfc_run_1"
CKPT_PATH_RETRIEVAL = f"./{RUN_ID_RETRIEVAL}.pth"

# --- 2. Initialize W&B ---
run_retrieval = wandb.init(
    project="eeg-caption-retrieval", # New project name
    entity="pbhuyan-carnegie-mellon-university", # *** YOUR W&B ENTITY ***
    job_type="train",
    config=config_retrieval,
    id=RUN_ID_RETRIEVAL,
    resume="allow",
)

# --- 3. Instantiate Models ---
#    a) Load your trained EEG classification backbone (e.g., ModelFC)
#       OR instantiate a new one if you want to train from scratch.
#       Let's assume you want to use the already trained 'model_fc'
eeg_backbone = model_fc # Using the model trained in the previous cell

#    b) Create the full EEG-Caption model
eeg_retrieval_model = EEGCaptionModel(eeg_backbone).to(DEVICE)

#    c) Load the FROZEN CLIP model (for text embeddings and logit scale)
#       (We already loaded 'clip_model' and 'clip_processor' in the previous cell)
clip_model.eval()
for param in clip_model.parameters():
    param.requires_grad = False

# --- 4. Define Optimizer ---
# IMPORTANT: Only optimize the parameters of the EEG retrieval model
optimizer = optim.Adam(eeg_retrieval_model.parameters(), lr=config_retrieval["lr"])

# --- 5. Load Checkpoint if Resuming ---
start_epoch_retrieval = 0
if wandb.run.resumed:
    print(f"Resuming retrieval run {RUN_ID_RETRIEVAL}...")
    try:
        checkpoint = torch.load(CKPT_PATH_RETRIEVAL)
        eeg_retrieval_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch_retrieval = checkpoint['epoch'] + 1
        print(f"Resumed successfully. Starting from epoch {start_epoch_retrieval}")
    except FileNotFoundError:
        print("No retrieval checkpoint file found.")
    except Exception as e:
        print(f"Error loading retrieval checkpoint: {e}.")

wandb.watch(eeg_retrieval_model, contrastive_loss, log="all", log_freq=100)

# --- 6. Retrieval Training Loop ---
print(f"--- Starting EEG-Caption Retrieval Training from Epoch {start_epoch_retrieval+1} ---")
epoch_bar_retrieval = tqdm(range(start_epoch_retrieval, config_retrieval["epochs"]), desc="Retrieval Epochs")

for epoch in epoch_bar_retrieval:
    
    # --- Training Phase ---
    eeg_retrieval_model.train()
    train_loss = 0.0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1:03} Training", leave=False)
    
    for eeg_batch, _, caption_batch, _ in train_bar: 
        eeg_batch = eeg_batch.to(DEVICE)
        
        # 1. Get EEG embeddings from your trainable model
        eeg_embeds = eeg_retrieval_model(eeg_batch)
        
        # 2. Get Text embeddings from the FROZEN CLIP model
        with torch.no_grad():
            text_inputs = clip_processor(
                text=list(caption_batch), return_tensors="pt", padding=True, truncation=True
            ).to(DEVICE)
            txt_embeds = clip_model.get_text_features(**text_inputs)
        
        # 3. Calculate Contrastive Loss
        optimizer.zero_grad()
        loss = contrastive_loss(eeg_embeds, txt_embeds, temperature=config_retrieval["contrastive_temperature"])
        
        # 4. Backpropagate and update weights (only affects EEG model)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        train_bar.set_postfix(loss=f"{loss.item():.4f}")
        
    avg_train_loss = train_loss / len(train_loader)

    # --- Validation Phase ---
    eeg_retrieval_model.eval()
    val_loss = 0.0
    all_eeg_embeds = []
    all_txt_embeds = []
    val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1:03} Validation", leave=False)
    
    with torch.no_grad():
        for eeg_batch, _, caption_batch, _ in val_bar:
            eeg_batch = eeg_batch.to(DEVICE)
            
            # 1. Get EEG embeddings
            eeg_embeds = eeg_retrieval_model(eeg_batch)
            
            # 2. Get Text embeddings
            text_inputs = clip_processor(
                text=list(caption_batch), return_tensors="pt", padding=True, truncation=True
            ).to(DEVICE)
            txt_embeds = clip_model.get_text_features(**text_inputs)

            # Calculate loss (optional, but good for monitoring)
            loss = contrastive_loss(eeg_embeds, txt_embeds, temperature=config_retrieval["contrastive_temperature"])
            val_loss += loss.item()

            # Normalize embeddings for recall calculation
            eeg_embeds /= eeg_embeds.norm(dim=-1, keepdim=True)
            txt_embeds /= txt_embeds.norm(dim=-1, keepdim=True)
            
            all_eeg_embeds.append(eeg_embeds.cpu())
            all_txt_embeds.append(txt_embeds.cpu())

    avg_val_loss = val_loss / len(val_loader)
    
    # Calculate Recall@k on validation set
    all_eeg_embeds = torch.cat(all_eeg_embeds).to(DEVICE)
    all_txt_embeds = torch.cat(all_txt_embeds).to(DEVICE)
    val_sim_matrix = all_eeg_embeds @ all_txt_embeds.T
    recall_results = calculate_recall(val_sim_matrix, k_values=(1, 5))
    
    # --- 7. Log Metrics to W&B ---
    wandb.log({
        "epoch": epoch,
        "train_loss": avg_train_loss,
        "val_loss": avg_val_loss,
        "val_recall_at_1": recall_results[1],
        "val_recall_at_5": recall_results[5],
    })
    
    # Update the main epoch bar
    epoch_bar_retrieval.set_postfix(
        Train_Loss=f"{avg_train_loss:.4f}", 
        Val_Loss=f"{avg_val_loss:.4f}", 
        Val_R_at_1=f"{recall_results[1]*100:.2f}%",
        Val_R_at_5=f"{recall_results[5]*100:.2f}%"
    )

    # --- 8. Save Checkpoint ---
    torch.save({
        'epoch': epoch,
        'model_state_dict': eeg_retrieval_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_val_loss, # Save validation loss
    }, CKPT_PATH_RETRIEVAL)

# --- 9. Finish the W&B Run ---
wandb.finish()
print(f"--- {config_retrieval['model_name']} training complete. ---")

Resuming retrieval run eeg_retrieval_modelfc_run_1...
No retrieval checkpoint file found.
--- Starting EEG-Caption Retrieval Training from Epoch 1 ---


Retrieval Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 001 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 001 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 002 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 002 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 003 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 003 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 004 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 004 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 005 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 005 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 006 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 006 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 007 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 007 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 008 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 008 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 009 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 009 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 010 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 010 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 011 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 011 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 012 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 012 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 013 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 013 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 014 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 014 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 015 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 015 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 016 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 016 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 017 Training:   0%|          | 0/488 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch 017 Validation:   0%|          | 0/163 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

KeyboardInterrupt: 