In [1]:
# Install only what's needed
!pip install torch-geometric -q
!pip install rdkit
import torch
from torch.utils.data import IterableDataset
from torch_geometric.datasets import MoleculeNet
import pandas as pd
import itertools
import sys
from google.colab import drive
import os

gdrive_path='/content/gdrive/MyDrive/Project_HIV'

# This will mount your google drive under 'MyDrive'
drive.mount('/content/gdrive', force_remount=True)
# In order to access the files in this notebook we have to navigate to the correct folder
os.chdir(gdrive_path)
# Check manually if all files are present
print(sorted(os.listdir()))


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m27.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting rdkit
  Downloading rdkit-2025.3.5-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.1 kB)
Downloading rdkit-2025.3.5-cp312-cp312-manylinux_2_28_x86_64.whl (36.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.2/36.2 MB[0m [31m40.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2025.3.5
Mounted at /content/gdrive
['.git', '.gitignore', 'HIV.ipynb', 'README.md', '__pycache__', 'checkpoint_batch_2000.pt', 'code', 'data_import.py', 'datasets', 'hiv_train.jsonl', 'hiv_val.jsonl', 'pythia_pe_peft.py', 'utils.py']


In [2]:
# Add code directory to path
from torch_geometric.datasets import MoleculeNet
# sys.path.insert(0, os.path.join(os.getcwd(), 'code'))

# # Import the Colab-friendly version
# from data.dataset_colab import PartialHIVDataset

class PartialHIVDataset(IterableDataset):
    """
    IterableDataset that loads only a small portion of the HIV dataset.
    """

    def __init__(self, root='/tmp/HIV', max_samples=10):
        """
        Args:
            root: Root directory for dataset storage
            max_samples: Maximum number of samples to load
        """
        self.root = root
        self.max_samples = max_samples
        self._dataset = None

    def _lazy_load_dataset(self):
        """Lazy load the dataset only when iteration begins"""
        if self._dataset is None:
            print(f"Initializing HIV dataset (will only load {self.max_samples} samples)...")
            # Import here to ensure it's available

            self._dataset = MoleculeNet(root=self.root, name='HIV')
            print(f"Dataset ready! Total size: {len(self._dataset)} molecules")
            print(f"But we'll only load {self.max_samples} of them.\n")

    def parse_molecules(self):
        """
        Parse molecules from the dataset, stopping after max_samples.
        """
        self._lazy_load_dataset()

        for i in range(min(self.max_samples, len(self._dataset))):
            data = self._dataset[i]

            # Extract molecule information
            mol_info = {
                'source': f"molecule_{i}",  # Similar to your CustomIterableDataset
                'target': data.y.item(),    # HIV activity label
                'num_atoms': data.num_nodes,
                'num_bonds': data.num_edges // 2,  # Undirected edges
                'smiles': getattr(data, 'smiles', 'N/A')
            }

            yield mol_info

    def __iter__(self):
        """Iterator with worker support"""
        iterator = self.parse_molecules()
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is not None:
            worker_total_num = worker_info.num_workers
            worker_id = worker_info.id
            return itertools.islice(iterator, worker_id, None, worker_total_num)

        return iterator


# Your code
batch_list = []
batch_size = 40

# Create dataset
dataset = PartialHIVDataset(max_samples=batch_size)
iterator = iter(dataset)
full_batch_data = list(dataset)

df = pd.DataFrame(full_batch_data)

Initializing HIV dataset (will only load 40 samples)...


Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv
Processing...
Done!


Dataset ready! Total size: 41120 molecules
But we'll only load 40 of them.



In [3]:
import codecs

import numpy as np
import torch.nn as nn
from transformers import AutoTokenizer

smiles_list = []
batch_size = 40


# Iterate through the dataset and fill it with the SMILES strings
for _ in range(batch_size):
  mol_info = next(iterator)
  smiles_list.append(mol_info['smiles'])



# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

# First, tokenize without padding to find the actual max length
tokenized_lengths = []
for smiles in smiles_list:
    tokens = tokenizer(smiles, padding=False, truncation=False)
    tokenized_lengths.append(len(tokens['input_ids']))

# Determine max_length from your data
max_length = max(tokenized_lengths)
print(f"Maximum sequence length in data: {max_length}")


# Now tokenize with the data-driven max_length
encoded = tokenizer(
    smiles_list,
    padding='max_length',
    truncation=True,
    max_length=max_length,
    return_tensors='pt'
)

# Create embedding layer
embedding = nn.Embedding(
    num_embeddings=len(tokenizer),
    embedding_dim=512,
    padding_idx=tokenizer.pad_token_id
)

# Get embeddings directly
embeddings = embedding(encoded['input_ids'])
print(f"Embeddings shape: {embeddings.shape}")  # (batch_size, max_length, 512)



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/166 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/501 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/150 [00:00<?, ?B/s]

Maximum sequence length in data: 61
Embeddings shape: torch.Size([40, 61, 512])


In [4]:
from scipy.sparse.linalg import eigsh
from rdkit import Chem
def smiles_to_adjacency_matrix(smiles):
    """Convert SMILES to adjacency matrix."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    n_atoms = mol.GetNumAtoms()
    adj_matrix = np.zeros((n_atoms, n_atoms))

    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        # Undirected graph
        adj_matrix[i, j] = 1
        adj_matrix[j, i] = 1

    return adj_matrix

def compute_graph_positional_encoding(adj_matrix, k=30):
    """
    Compute graph positional encodings from eigenvectors of the
    symmetrically normalized graph Laplacian.
    """
    n = adj_matrix.shape[0]

    # Compute degree matrix
    degree = np.sum(adj_matrix, axis=1)
    degree[degree == 0] = 1

    # D^(-1/2)
    d_inv_sqrt = np.diag(1.0 / np.sqrt(degree))

    # Symmetrically normalized Laplacian
    identity = np.eye(n)
    normalized_adj = d_inv_sqrt @ adj_matrix @ d_inv_sqrt
    laplacian = identity - normalized_adj

    # Compute eigenvectors
    if n < k:
        # If graph has fewer nodes than k, pad with zeros
        eigenvalues, eigenvectors = np.linalg.eigh(laplacian)
        # Pad eigenvectors to have k columns
        padded_eigenvectors = np.zeros((n, k))
        padded_eigenvectors[:, :n] = eigenvectors
        return padded_eigenvectors
    else:
        eigenvalues, eigenvectors = eigsh(laplacian, k=k, which='SM')
        return eigenvectors

In [5]:
# Process all SMILES and compute their graph PEs
graph_pes_list = []
max_nodes = 0

for smiles in smiles_list:
    adj_matrix = smiles_to_adjacency_matrix(smiles)
    if adj_matrix is not None:
        pe = compute_graph_positional_encoding(adj_matrix, k=30)
        graph_pes_list.append(pe)
        max_nodes = max(max_nodes, pe.shape[0])
    else:
        graph_pes_list.append(None)

print(f"Maximum number of atoms in dataset: {max_nodes}")


Maximum number of atoms in dataset: 42


In [6]:
import torch
import numpy as np
from rdkit import Chem
import re

# def generate_random_orthonormal_pe(n_vectors, dim=30):
#     """Generate n_vectors random orthonormal vectors of dimension dim."""
#     if n_vectors == 0:
#         return np.zeros((0, dim))

#     # Generate random matrix
#     random_matrix = np.random.randn(dim, n_vectors)

#     # Use QR decomposition to get orthonormal vectors
#     Q, _ = np.linalg.qr(random_matrix)

#     # Return first n_vectors columns (transposed to have shape (n_vectors, dim))
#     return Q[:, :n_vectors].T

def generate_zero_pe(n_vectors, dim=30):
    """Generate n_vectors zero vectors of dimension dim for non-atom characters."""
    return np.zeros((n_vectors, dim))

def parse_token_components(token):
    """
    Parse a token to identify atoms and characters.
    Returns: (atom_indices, n_characters)
    """
    # Common atom patterns in SMILES
    atom_pattern = r'(Cl|Br|Si|Mg|Ca|Fe|Al|Na|Li|[BCNOFPSKHIV])'

    # Find all atoms in the token
    atoms = re.findall(atom_pattern, token)

    # Count non-atom characters
    # Remove atoms from token to count remaining characters
    remaining = token
    for atom in atoms:
        remaining = remaining.replace(atom, '', 1)

    # Count actual characters (digits, +, -, =, #, etc.)
    n_characters = len([c for c in remaining])

    return atoms, n_characters

def align_graph_pe_to_tokens(smiles, graph_pe, tokenizer, max_length,random_seed=None):
    """
    Align graph PE to tokens, handling pure characters, atoms, and mixed tokens.

    Args:
        smiles: SMILES string
        graph_pe: Graph positional encoding for atoms (n_atoms, embedding_dim)
        tokenizer: Tokenizer object
        max_length: Maximum sequence length
        random_seed: Random seed for reproducible random PEs
    """
    if random_seed is not None:
        np.random.seed(random_seed)
    # Determine embedding dimension
    if graph_pe is not None:
        embedding_dim = graph_pe.shape[1]
    elif embedding_dim is None:
        embedding_dim = 30  # Default fallback

    mol = Chem.MolFromSmiles(smiles)
    if mol is None or graph_pe is None:
        return torch.zeros(max_length, embedding_dim)

    # Get tokens
    encoding = tokenizer(smiles, padding='max_length', max_length=max_length)
    tokens = tokenizer.convert_ids_to_tokens(encoding['input_ids'])


    # Initialize token PE
    token_pe = torch.zeros(max_length, embedding_dim)

    # Build atom mapping from SMILES
    atom_symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
    atom_count = {symbol: 0 for symbol in set(atom_symbols)}
    atom_to_idx = {}

    for idx, symbol in enumerate(atom_symbols):
        atom_to_idx[(symbol, atom_count[symbol])] = idx
        atom_count[symbol] += 1

    # Reset atom count for tracking
    current_atom_count = {symbol: 0 for symbol in atom_count}

    # Process each token
    for i, token in enumerate(tokens):
        if token in ['<s>', '</s>', '<pad>']:
            continue

        # Parse token components
        atoms_in_token, n_characters = parse_token_components(token)

        if len(atoms_in_token) == 0 and n_characters > 0:
            # Pure character token - use random orthonormal PE
            random_pes = generate_zero_pe(n_characters, embedding_dim)
            token_pe[i] = torch.tensor(random_pes.sum(axis=0))

        elif len(atoms_in_token) > 0 and n_characters == 0:
            # Pure atom token(s) - use graph PE
            atom_pes = []
            for atom_symbol in atoms_in_token:
                if atom_symbol in current_atom_count:
                    atom_key = (atom_symbol, current_atom_count[atom_symbol])
                    if atom_key in atom_to_idx:
                        atom_idx = atom_to_idx[atom_key]
                        atom_pes.append(graph_pe[atom_idx])
                        current_atom_count[atom_symbol] += 1

            if atom_pes:
                # sum the PEs of all atoms in this token
                token_pe[i] = torch.tensor(np.sum(atom_pes, axis=0))

        elif len(atoms_in_token) > 0 and n_characters > 0:
            # Mixed token - combine atom PE and character PE
            # Get atom PEs
            atom_pes = []
            for atom_symbol in atoms_in_token:
                if atom_symbol in current_atom_count:
                    atom_key = (atom_symbol, current_atom_count[atom_symbol])
                    if atom_key in atom_to_idx:
                        atom_idx = atom_to_idx[atom_key]
                        atom_pes.append(graph_pe[atom_idx])
                        current_atom_count[atom_symbol] += 1

            # Get character PEs
            random_pes = generate_zero_pe(n_characters, embedding_dim)

            # Combine: sum of atoms + sum of characters
            combined_pe = np.zeros(embedding_dim)
            if atom_pes:
                combined_pe += np.sum(atom_pes, axis=0)
            combined_pe += random_pes.sum(axis=0)

            token_pe[i] = torch.tensor(combined_pe)

    return token_pe



In [7]:
token_pes_list = []

for smiles, graph_pe in zip(smiles_list, graph_pes_list):
    # Align graph PE to tokens using the provided function
    token_pe = align_graph_pe_to_tokens(
        smiles=smiles,
        graph_pe=graph_pe,
        tokenizer=tokenizer,
        max_length=max_length,
        random_seed=42
    )
    token_pes_list.append(token_pe)

# Stack into batch tensor
token_pes_batch = torch.stack(token_pes_list)

In [8]:
class EmbeddingWithGraphPE(nn.Module):
    def __init__(self, embed_dim=512, pe_dim=30):
        super().__init__()
        # One-layer projection with GeLU from graph PE to embedding dimension
        # Using standard Laplacian, so pe_dim is just k
        self.pe_projection = nn.Sequential(
            nn.Linear(pe_dim, embed_dim),
            nn.GELU()
        )

    def forward(self, embeddings, token_pes):
        # token_pes shape: [batch_size, seq_len, pe_dim]
        # embeddings shape: [batch_size, seq_len, embed_dim]

        # Project token PEs to embedding dimension with GeLU
        projected_pes = self.pe_projection(token_pes)

        # Add to token embeddings
        enhanced_embeddings = embeddings + projected_pes

        return enhanced_embeddings

# Usage
model = EmbeddingWithGraphPE(embed_dim=512, pe_dim=30)
# graph_pes_batch shape: [40, 61, 30] (30 eigenvectors from standard Laplacian)
# embeddings shape: [40, 61, 512]
enhanced_embeddings = model(embeddings, token_pes_batch)
# Output shape: [40, 61, 512]

In [9]:
enhanced_embeddings.shape

torch.Size([40, 61, 512])

In [10]:
import json
from data_import import check_class_balance, stratified_train_val_split, convert_to_litgpt_format, save_to_jsonl

# Create dataset
batch_size = 41120  # Adjust based on your needs
dataset = PartialHIVDataset(max_samples=batch_size)

# Convert to DataFrame
full_batch_data = list(dataset)
df = pd.DataFrame(full_batch_data)


# Check balance in the original DataFrame
original_balance = check_class_balance(df, "Original Dataset")

# Convert to LitGPT format for different task types
classification_data = convert_to_litgpt_format(df, task_type="classification")

# Check balance in the converted data
converted_balance = check_class_balance(classification_data, "Converted Dataset")

# Create stratified train/validation split
train_data, val_data = stratified_train_val_split(classification_data, train_ratio=0.8)

# Check balance in train and validation sets
train_balance = check_class_balance(train_data, "Training Set")
val_balance = check_class_balance(val_data, "Validation Set")

# Calculate difference from original distribution
train_diff_0 = abs(train_balance['ratio_0'] - original_balance['ratio_0'])
train_diff_1 = abs(train_balance['ratio_1'] - original_balance['ratio_1'])
val_diff_0 = abs(val_balance['ratio_0'] - original_balance['ratio_0'])
val_diff_1 = abs(val_balance['ratio_1'] - original_balance['ratio_1'])


# Save to JSONL files
save_to_jsonl(train_data, "hiv_train.jsonl")
save_to_jsonl(val_data, "hiv_val.jsonl")


Initializing HIV dataset (will only load 41120 samples)...
Dataset ready! Total size: 41120 molecules
But we'll only load 41120 of them.


Original Dataset Class Balance:
Total samples: 41120
Class 0.0: 39677 samples (96.49%)
Class 1.0: 1443 samples (3.51%)

Converted Dataset Class Balance:
Total samples: 41120
Class 0: 39677 samples (96.49%)
Class 1: 1443 samples (3.51%)

Training Set Class Balance:
Total samples: 32895
Class 0: 31741 samples (96.49%)
Class 1: 1154 samples (3.51%)

Validation Set Class Balance:
Total samples: 8225
Class 0: 7936 samples (96.49%)
Class 1: 289 samples (3.51%)
Saved 32895 entries to hiv_train.jsonl
Saved 8225 entries to hiv_val.jsonl


In [11]:
import json
import torch
from utils import calculate_max_lengths
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

def load_jsonl(filepath):
    """Load data from JSONL file"""
    data = []
    with open(filepath, 'r') as f:
        for line in f:
            data.append(json.loads(line))
    return data

class HIVDataset(Dataset):
    def __init__(self, jsonl_filepath, tokenizer, max_full_length):
        self.data = load_jsonl(jsonl_filepath)
        self.tokenizer = tokenizer
        self.max_full_length = max_full_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Extract SMILES
        smiles = item['input'].replace("SMILES: ", "")

        # Create full prompt
        full_prompt = f"{item['instruction']}\n{item['input']}\nAnswer:"
        # Compute graph PE with actual SMILES length
        smiles_tokens = self.tokenizer(smiles, add_special_tokens=False)
        actual_smiles_length = len(smiles_tokens['input_ids'])
        # Get prompt length for label masking
        prompt_encoding = self.tokenizer(full_prompt, add_special_tokens=True, return_tensors='pt')
        prompt_length = prompt_encoding['input_ids'].shape[1]

        # Full sequence for training
        target_text = f"{full_prompt} {item['output']}"
        full_encoding = self.tokenizer(
            target_text,
            padding='max_length',
            max_length=self.max_full_length,
            truncation=True,
            return_tensors='pt'
        )

        # Compute graph PE for SMILES only
        adj_matrix = smiles_to_adjacency_matrix(smiles)
        if adj_matrix is not None:
            graph_pe = compute_graph_positional_encoding(adj_matrix, k=30)
            smiles_token_pe = align_graph_pe_to_tokens(
                smiles, graph_pe, self.tokenizer, actual_smiles_length, random_seed=42
            )
        else:
            smiles_token_pe = torch.zeros(actual_smiles_length, 30)

        # Create full-length PE tensor
        full_token_pe = torch.zeros(self.max_full_length, 30)

        # Find where SMILES tokens are and place PEs there
        # This is approximate - you might need better alignment
        instruction_text = f"{item['instruction']}\nSMILES: "
        instruction_tokens = self.tokenizer(instruction_text, add_special_tokens=True)
        smiles_start_idx = len(instruction_tokens['input_ids'])

        # Copy SMILES PEs to correct position
        pe_end_idx = min(smiles_start_idx + actual_smiles_length, self.max_full_length)


        pe_length = min(smiles_token_pe.shape[0], pe_end_idx - smiles_start_idx)
        full_token_pe[smiles_start_idx:smiles_start_idx + pe_length] = smiles_token_pe[:pe_length]
        # Create labels with prompt masked
        labels = full_encoding['input_ids'][0].clone()
        labels[:prompt_length] = -100  # Mask prompt tokens

        return {
            'input_ids': full_encoding['input_ids'][0],
            'graph_pes': full_token_pe,
            'labels': labels,
        }

# First calculate max lengths
max_full_length = calculate_max_lengths('hiv_train.jsonl')

# Then create dataloaders with those lengths
def create_dataloaders(train_path='hiv_train.jsonl', val_path='hiv_val.jsonl',
                      batch_size=8, tokenizer=None, max_full_length=max_full_length):
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    train_dataset = HIVDataset(train_path, tokenizer, max_full_length)
    val_dataset = HIVDataset(val_path, tokenizer, max_full_length)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, tokenizer

# Usage
train_loader, val_loader, tokenizer = create_dataloaders(batch_size=8)

# Test one batch
sample_batch = next(iter(train_loader))
print(f"Input IDs shape: {sample_batch['input_ids'].shape}")
print(f"Graph PEs shape: {sample_batch['graph_pes'].shape}")
print(f"Labels shape: {sample_batch['labels'].shape}")

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Input IDs shape: torch.Size([8, 498])
Graph PEs shape: torch.Size([8, 498, 30])
Labels shape: torch.Size([8, 498])


In [12]:
# Look at first example in batch
i = 0
print(f"Input text (decoded): {tokenizer.decode(sample_batch['input_ids'][i])}")
print(f"\nNon-zero PE positions: {torch.nonzero(sample_batch['graph_pes'][i].sum(dim=1)).squeeze().tolist()}")
print(f"\nLabel tokens that aren't -100: {sample_batch['labels'][i][sample_batch['labels'][i] != -100].tolist()}")

Input text (decoded): Classify the following molecule based on its HIV activity. Respond with '1' if the molecule shows HIV activity, or '0' if it does not.
SMILES: COc1cc2c(cc1O)C1(C(=O)c3ccc4c(c3C1O)OCO4)N(C)CC2
Answer: 0<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftex

In [13]:
# %pip install peft

In [46]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer,get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model, TaskType
from typing import Optional

class HIVPEModule(nn.Module):
    def __init__(self, pe_dim=30, embed_dim=768):
        super().__init__()
        self.pe_projection = nn.Linear(pe_dim, embed_dim)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(0.7)
        # Better initialization for stability
        nn.init.xavier_uniform_(self.pe_projection.weight, gain=0.1)
        nn.init.zeros_(self.pe_projection.bias)

        # Use a more reasonable scale that works with fp16
        self.scale = 0.1  #

    def forward(self, embeddings, graph_pes):
        if graph_pes is not None:
            # Ensure proper dtype
            graph_pes = graph_pes.to(embeddings.dtype)
            projected_pes = self.activation(self.pe_projection(graph_pes))
            projected_pes = self.dropout(projected_pes)
            # Use learnable scale parameter for better gradient flow
            return embeddings + self.scale * projected_pes
        return embeddings

class PythiaWithPE(nn.Module):
    """Wrapper that adds PE to Pythia model"""
    def __init__(self, base_model, pe_dim=30):
        super().__init__()
        self.base_model = base_model
        self.embed_dim = base_model.gpt_neox.embed_in.embedding_dim
        self.pe_module = HIVPEModule(pe_dim=pe_dim, embed_dim=self.embed_dim)

    def forward(
        self,
        input_ids=None,
        graph_pes=None,
        attention_mask=None,
        labels=None,
        **kwargs
    ):
        # Get embeddings
        inputs_embeds = self.base_model.gpt_neox.embed_in(input_ids)

        # Add PE with proper dtype handling
        if graph_pes is not None:
            inputs_embeds = self.pe_module(inputs_embeds, graph_pes)

        # Forward through the model with embeddings
        outputs = self.base_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )

        return outputs

    def generate(self, input_ids, graph_pes=None, **kwargs):
        """Generate with PE support"""
        # Get embeddings with PE for the prompt
        inputs_embeds = self.base_model.gpt_neox.embed_in(input_ids)
        if graph_pes is not None:
            inputs_embeds = self.pe_module(inputs_embeds, graph_pes)

        # Generate using embeddings
        return self.base_model.generate(
            inputs_embeds=inputs_embeds,
            **kwargs
        )

def setup_model_with_pe_and_lora(model_name="EleutherAI/pythia-160m", pe_dim=30):
    """
    Setup Pythia with custom PE and LoRA using PEFT
    """
    print(f"Loading {model_name}...")

    # Load base model - use float32 for training stability
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,  # Use float32 for stability
        device_map="auto"
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Wrap with PE
    model = PythiaWithPE(base_model, pe_dim=pe_dim)

    # Configure LoRA with conservative settings
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=4,
        lora_alpha=8,
        lora_dropout=0.3,  # Increased dropout
        target_modules=[
            "query_key_value",
            "dense_h_to_4h",
            "dense_4h_to_h"
        ],
        # Initialize LoRA weights with smaller values
        init_lora_weights="gaussian"
    )

    # Apply LoRA to the base model
    model.base_model = get_peft_model(model.base_model, lora_config)

    # Print trainable parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"✓ Total parameters: {total_params:,}")
    print(f"✓ Trainable parameters: {trainable_params:,}")
    print(f"✓ Trainable %: {100 * trainable_params / total_params:.2f}%")

    return model, tokenizer

def train_with_focal_loss(
    model,
    train_loader,
    val_loader=None,
    epochs=3,
    learning_rate=1e-5,
    device="cuda",
    tokenizer=None,save_checkpoint_at_batch=2000,  # Add this parameter
    checkpoint_path="checkpoint_batch_2000.pt"  # Add this parameter
):
    """Training loop with Focal Loss for extreme class imbalance"""

    model = model.to(device)

    # Focal loss parameters
    alpha = 0.965  # Weight for class 0 (should roughly match class distribution)
    gamma = 5.0   # Focusing parameter - increase this for more extreme imbalance

    gradient_accumulation_steps = 2
    validation_frequency = 100  # Check validation every 2000 batches

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=0.3,
        eps=1e-6
    )
    num_training_steps = (len(train_loader) * epochs) // gradient_accumulation_steps
    scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=num_training_steps)
    token_0 = 470  # " 0"
    token_1 = 337  # " 1"

    scaler = torch.amp.GradScaler('cuda') if device == "cuda" else None

    def focal_loss(logits, labels, alpha=0.1, gamma=5.0):
        """
        Focal loss implementation
        alpha: weight for class 0 (majority class)
        1-alpha: weight for class 1 (minority class)
        """
        # Get class probabilities
        probs = torch.softmax(logits, dim=-1)

        # Get probability of true class
        ce_loss = nn.functional.cross_entropy(logits, labels, reduction='none')
        pt = torch.exp(-ce_loss)

        # Apply class weights
        # For token_0 (majority), use alpha; for token_1 (minority), use 1-alpha
        alpha_t = torch.where(labels == token_0,  1 - alpha,alpha)

        # Focal term: (1 - pt)^gamma reduces loss for well-classified examples
        focal_weight = alpha_t * (1 - pt) ** gamma

        # Final focal loss
        focal_loss = focal_weight * ce_loss

        return focal_loss.mean()

    def weighted_cross_entropy(logits, labels):
      """Simple weighted CE for extreme class imbalance"""
      # Create class weights
      class_weights = torch.zeros(logits.size(-1), device=logits.device)
      class_weights[token_0] = 0.001  # Weight for majority class (3.5%)
      class_weights[token_1] = 0.999  # Weight for minority class (96.5%)

      return F.cross_entropy(logits, labels, weight=class_weights)

    def run_validation(model, val_loader, max_batches=None):
        """Run validation and return metrics"""
        model.eval()
        val_loss = 0
        val_class_0_correct = 0
        val_class_1_correct = 0
        val_class_0_total = 0
        val_class_1_total = 0
        val_batches = 0

        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                if max_batches and i >= max_batches:
                    break

                input_ids = batch['input_ids'].to(device)
                graph_pes = batch['graph_pes'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(input_ids=input_ids, graph_pes=graph_pes)
                logits = outputs.logits

                mask = ((labels == token_0) | (labels == token_1)) & (labels != -100)

                if mask.any():
                    logits_flat = logits.view(-1, logits.size(-1))
                    labels_flat = labels.view(-1)
                    mask_flat = mask.view(-1)

                    relevant_logits = logits_flat[mask_flat]
                    relevant_labels = labels_flat[mask_flat]

                    #loss = focal_loss(relevant_logits, relevant_labels, alpha=alpha, gamma=gamma)
                    loss = weighted_cross_entropy(relevant_logits, relevant_labels)
                    loss = torch.clamp(loss, min=0.001)  # Prevent loss from going below 0.001
                    loss = loss / gradient_accumulation_steps
                    if not torch.isnan(loss):
                        val_loss += loss.item()
                        val_batches += 1

                        preds = relevant_logits.argmax(dim=-1)
                        class_0_mask = relevant_labels == token_0
                        class_1_mask = relevant_labels == token_1

                        if class_0_mask.any():
                            val_class_0_correct += (preds[class_0_mask] == relevant_labels[class_0_mask]).sum().item()
                            val_class_0_total += class_0_mask.sum().item()

                        if class_1_mask.any():
                            val_class_1_correct += (preds[class_1_mask] == relevant_labels[class_1_mask]).sum().item()
                            val_class_1_total += class_1_mask.sum().item()

        model.train()
        return {
            'loss': val_loss / val_batches if val_batches > 0 else float('inf'),
            'class_0_acc': val_class_0_correct / val_class_0_total if val_class_0_total > 0 else 0,
            'class_1_acc': val_class_1_correct / val_class_1_total if val_class_1_total > 0 else 0,
            'class_0_total': val_class_0_total,
            'class_1_total': val_class_1_total
        }

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        valid_batches = 0
        class_0_correct = 0
        class_1_correct = 0
        class_0_total = 0
        class_1_total = 0

        # Track predictions distribution
        pred_0_count = 0
        pred_1_count = 0

        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(device)
            graph_pes = batch['graph_pes'].to(device)
            labels = batch['labels'].to(device)
            if batch_idx == save_checkpoint_at_batch:
                print(f"\nSaving checkpoint at batch {batch_idx}...")
                checkpoint = {
                    'batch': batch_idx,
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'scaler_state_dict': scaler.state_dict() if scaler else None,
                    'train_loss': total_loss / valid_batches if valid_batches > 0 else 0,
                    'class_0_acc': class_0_correct / class_0_total if class_0_total > 0 else 0,
                    'class_1_acc': class_1_correct / class_1_total if class_1_total > 0 else 0,
                }
                torch.save(checkpoint, checkpoint_path)
                print(f"Checkpoint saved to {checkpoint_path}")
            if scaler is not None:
                with torch.amp.autocast('cuda'):
                    outputs = model(input_ids=input_ids, graph_pes=graph_pes)
                    logits = outputs.logits

                    if torch.isnan(logits).any():
                        print(f"NaN detected at batch {batch_idx}, skipping...")
                        optimizer.zero_grad()
                        continue

                    # Your mask for 0/1 tokens
                    mask = ((labels == token_0) | (labels == token_1)) & (labels != -100)

                    if mask.any():
                        logits_flat = logits.view(-1, logits.size(-1))
                        labels_flat = labels.view(-1)
                        mask_flat = mask.view(-1)

                        relevant_logits = logits_flat[mask_flat]
                        relevant_labels = labels_flat[mask_flat]

                        # Apply focal loss
                        #loss = focal_loss(relevant_logits, relevant_labels, alpha=alpha, gamma=gamma)

                        # loss = loss / gradient_accumulation_steps
                        loss = weighted_cross_entropy(relevant_logits, relevant_labels)
                        # # Add L2 regularization on logits to prevent extreme predictions
                        logit_reg = 0.001 * (relevant_logits ** 2).mean()
                        loss = loss + logit_reg

                        loss = loss / gradient_accumulation_steps
                        # loss = loss + logit_reg / gradient_accumulation_steps

                        # Track predictions for monitoring
                        with torch.no_grad():
                            preds = relevant_logits.argmax(dim=-1)
                            pred_0_count += (preds == token_0).sum().item()
                            pred_1_count += (preds == token_1).sum().item()
                            # ADD THIS DIAGNOSTIC CODE:
                            if batch_idx % 50 == 0:  # Every 50 batches
                                # Get unique predicted tokens
                                unique_preds, counts = torch.unique(preds, return_counts=True)
                                print(f"\n  === Token Prediction Diagnostic (Batch {batch_idx}) ===")
                                print(f"  Expected tokens: {token_0} (' 0'), {token_1} (' 1')")
                                print(f"  Actually predicted tokens: {dict(zip(unique_preds.tolist(), counts.tolist()))}")

                                # Decode the predicted tokens to see what they are
                                print(f"  Decoded predictions:")
                                for token_id, count in zip(unique_preds.tolist(), counts.tolist()):
                                    decoded = tokenizer.decode([token_id])
                                    print(f"    Token {token_id}: '{decoded}' (count: {count})")

                                # Sample a few predictions to see what's happening
                                sample_size = min(5, len(preds))
                                sample_indices = torch.randperm(len(preds))[:sample_size]
                                print(f"\n  Sample predictions vs labels:")
                                for idx in sample_indices:
                                    pred_token = preds[idx].item()
                                    label_token = relevant_labels[idx].item()
                                    print(f"    Predicted: {pred_token} ('{tokenizer.decode([pred_token])}') | Label: {label_token} ('{tokenizer.decode([label_token])}')")

                                # Check logits for tokens 470 and 337
                                print(f"\n  Logit values for key tokens (first 3 samples):")
                                for i in range(min(3, len(relevant_logits))):
                                    logit_470 = relevant_logits[i, token_0].item()
                                    logit_337 = relevant_logits[i, token_1].item()
                                    print(f"    Sample {i}: token_470=' 0' logit={logit_470:.3f}, token_337=' 1' logit={logit_337:.3f}")
                            class_0_mask = relevant_labels == token_0
                            class_1_mask = relevant_labels == token_1

                            if class_0_mask.any():
                                class_0_correct += (preds[class_0_mask] == relevant_labels[class_0_mask]).sum().item()
                                class_0_total += class_0_mask.sum().item()

                            if class_1_mask.any():
                                class_1_correct += (preds[class_1_mask] == relevant_labels[class_1_mask]).sum().item()
                                class_1_total += class_1_mask.sum().item()
                    else:
                        continue

                scaler.scale(loss).backward()

                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    scaler.unscale_(optimizer)

                    # Log gradient norm
                    if batch_idx % 50 == 0:
                        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        print(f"  Gradient norm: {total_norm:.4f}")
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    scheduler.step()

            if mask.any() and not torch.isnan(loss):
                total_loss += loss.item() * gradient_accumulation_steps
                valid_batches += 1

                if batch_idx % 10 == 0:
                    print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, "
                          f"Loss: {loss.item() * gradient_accumulation_steps:.4f}")

                    # Print detailed statistics every 50 batches
                    if batch_idx > 0 and batch_idx % 50 == 0:
                        print(f"  Class 0 accuracy: {class_0_correct/class_0_total:.2%} ({class_0_correct}/{class_0_total})")
                        if class_1_total > 0:
                            print(f"  Class 1 accuracy: {class_1_correct/class_1_total:.2%} ({class_1_correct}/{class_1_total})")
                        else:
                            print(f"  Class 1 accuracy: No class 1 samples yet")
                        print(f"  Predictions distribution: {pred_0_count} zeros, {pred_1_count} ones")

                        # Reset prediction counters
                        pred_0_count = 0
                        pred_1_count = 0

                # Periodic validation check
                if val_loader is not None and batch_idx > 0 and batch_idx % validation_frequency == 0:
                    print(f"\n--- Validation Check at Batch {batch_idx} ---")
                    val_metrics = run_validation(model, val_loader, max_batches=100)  # Quick check on 100 batches
                    print(f"Validation Loss: {val_metrics['loss']:.4f}")
                    print(f"Validation Class 0 accuracy: {val_metrics['class_0_acc']:.2%} "
                          f"({int(val_metrics['class_0_acc'] * val_metrics['class_0_total'])}/{val_metrics['class_0_total']})")
                    if val_metrics['class_1_total'] > 0:
                        print(f"Validation Class 1 accuracy: {val_metrics['class_1_acc']:.2%} "
                              f"({int(val_metrics['class_1_acc'] * val_metrics['class_1_total'])}/{val_metrics['class_1_total']})")
                    print("--- End Validation Check ---\n")

        # End of epoch summary
        print(f"\n{'='*50}")
        print(f"Epoch {epoch+1} Summary:")
        print(f"Average Loss: {total_loss/valid_batches:.4f}")
        if class_0_total > 0:
            print(f"Class 0 accuracy: {class_0_correct/class_0_total:.2%} ({class_0_correct}/{class_0_total})")
        if class_1_total > 0:
            print(f"Class 1 accuracy: {class_1_correct/class_1_total:.2%} ({class_1_correct}/{class_1_total})")
        else:
            print("No class 1 samples in this epoch!")
        print(f"{'='*50}\n")

        # Full validation at epoch end
        if val_loader is not None:
            print("Running full validation...")
            val_metrics = run_validation(model, val_loader)  # Full validation
            print(f"Full Validation Results:")
            print(f"  Loss: {val_metrics['loss']:.4f}")
            print(f"  Class 0 accuracy: {val_metrics['class_0_acc']:.2%}")
            if val_metrics['class_1_total'] > 0:
                print(f"  Class 1 accuracy: {val_metrics['class_1_acc']:.2%}")
            print()

    return model

# Updated usage code
if __name__ == "__main__":
    # Setup model and tokenizer
    model, tokenizer = setup_model_with_pe_and_lora()

    # Use the correct device and dtype setup
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # Don't convert to half precision manually - let autocast handle it
    # model.pe_module = model.pe_module.half()  # Remove this line

    print("Model setup complete and ready for training!")

Loading EleutherAI/pythia-160m...
✓ Total parameters: 162,862,848
✓ Trainable parameters: 539,904
✓ Trainable %: 0.33%
Model setup complete and ready for training!


In [47]:
import gc

# Clear GPU memory
torch.cuda.empty_cache()
gc.collect()
# Setup model and tokenizer
model, tokenizer = setup_model_with_pe_and_lora()

# Convert model to device
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)


# Use smaller batch size to ensure minority class representation
train_loader, val_loader, _ = create_dataloaders(
    batch_size=10,  # Reduced to ensure we see minority class
    tokenizer=tokenizer
)

# Check if we actually see class 1 in the data
print("\nChecking class 1 presence in first 10 batches:")
class_1_count = 0
for i, batch in enumerate(train_loader):
    if i >= 10:
        break
    labels = batch['labels']
    if (labels == 337).any():
        class_1_count += 1
print(f"Found class 1 in {class_1_count}/10 batches")

# Train with focal loss (use the new function)
trained_model = train_with_focal_loss(  # Changed from train_with_pe
    model,
    train_loader,
    val_loader,
    epochs=3,
    learning_rate=1e-4,
    device=device,
    tokenizer=tokenizer,save_checkpoint_at_batch=2000,checkpoint_path="checkpoint_batch_2000.pt"
)

Loading EleutherAI/pythia-160m...
✓ Total parameters: 162,862,848
✓ Trainable parameters: 539,904
✓ Trainable %: 0.33%

Checking class 1 presence in first 10 batches:


  eigenvalues, eigenvectors = eigsh(laplacian, k=k, which='SM')


Found class 1 in 2/10 batches

  === Token Prediction Diagnostic (Batch 0) ===
  Expected tokens: 470 (' 0'), 337 (' 1')
  Actually predicted tokens: {187: 10}
  Decoded predictions:
    Token 187: '
' (count: 10)

  Sample predictions vs labels:
    Predicted: 187 ('
') | Label: 470 (' 0')
    Predicted: 187 ('
') | Label: 470 (' 0')
    Predicted: 187 ('
') | Label: 470 (' 0')
    Predicted: 187 ('
') | Label: 470 (' 0')
    Predicted: 187 ('
') | Label: 470 (' 0')

  Logit values for key tokens (first 3 samples):
    Sample 0: token_470=' 0' logit=830.500, token_337=' 1' logit=831.000
    Sample 1: token_470=' 0' logit=832.000, token_337=' 1' logit=832.500
    Sample 2: token_470=' 0' logit=831.500, token_337=' 1' logit=831.000
Epoch 1/3, Batch 0/3290, Loss: 669.2621
Epoch 1/3, Batch 10/3290, Loss: 669.4377
Epoch 1/3, Batch 20/3290, Loss: 668.3129
Epoch 1/3, Batch 30/3290, Loss: 661.7687
Epoch 1/3, Batch 40/3290, Loss: 648.8789

  === Token Prediction Diagnostic (Batch 50) ===
  Exp

KeyboardInterrupt: 