In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
import pandas as pd

In [None]:
# Colab-only: mount Google Drive at /content/drive so the corpus folder under
# /content/drive/MyDrive used by the loading cell is reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Maps the level suffix of a filename (its last two '_'-separated parts,
# e.g. "B1_1" in SM_CHN_PTJ1_146_B1_1.txt) to the class index used as label.
LEVEL_TO_INDEX = {
    'A2_0': 0,
    'B1_1': 1,
    'B1_2': 2,
    'B2_0': 3
}

def load_texts_and_labels(folder_path, max_samples=4000):
    """Read labelled ``.txt`` files from ``folder_path``.

    The label is derived from the filename: the last two ``_``-separated
    parts of the name (without the ``.txt`` suffix) are looked up in
    ``LEVEL_TO_INDEX``. Files with an unknown level are reported and skipped.

    Args:
        folder_path: Directory containing the ``.txt`` files.
        max_samples: Maximum number of samples to load. ``None`` loads
            every matching file. Defaults to 4000.

    Returns:
        ``(texts, labels)``: two equally long lists holding the file contents
        and the corresponding class indices.
    """
    suffix = '.txt'
    texts = []
    labels = []

    # os.listdir() returns entries in arbitrary order; sorting makes the
    # subset kept under `max_samples` identical on every run.
    for filename in sorted(os.listdir(folder_path)):
        if max_samples is not None and len(texts) >= max_samples:
            break
        if not filename.endswith(suffix):
            continue

        # Strip only the trailing extension, then take the last two parts
        # (e.g. SM_CHN_PTJ1_146_B1_1.txt -> B1_1).
        stem = filename[:-len(suffix)]
        level = '_'.join(stem.split('_')[-2:])
        if level not in LEVEL_TO_INDEX:
            print(f"Skipping file {filename} due to unknown level {level}")
            continue

        with open(os.path.join(folder_path, filename), 'r', encoding='utf-8') as f:
            texts.append(f.read())
        labels.append(LEVEL_TO_INDEX[level])

    return texts, labels

# Location of the essay corpus on the mounted Drive.
# Alternative location: '/content/drive/MyDrive/CULI_project/SM_0_Unclassified_Unmerged'
folder_path = '/content/drive/MyDrive/SM_0_Unclassified_Unmerged'
texts, labels = load_texts_and_labels(folder_path)

print(f"Loaded {len(texts)} samples.")

# Fail fast with an actionable message instead of a bare IndexError on
# texts[0] (and a ZeroDivisionError in the length statistics further down).
if not texts:
    raise RuntimeError(
        f"No labelled .txt files were loaded from {folder_path!r}; "
        "check that Drive is mounted and the folder path is correct."
    )

print("Example:", texts[0][:100], "Label Index:", labels[0])

In [None]:
# Average text length in characters (len() of a str counts characters).
total_length = sum(len(text) for text in texts)
mean_length = total_length / len(texts)

print(f"Mean length of texts: {mean_length:.2f}")

In [None]:
# Same statistic via pandas; `text_lengths` keeps the per-text character
# counts around for further inspection (e.g. text_lengths.describe()).
text_series = pd.Series(texts)
text_lengths = text_series.apply(len)
mean_length_pd = text_lengths.mean()

# The value was previously computed but never shown.
print(f"Mean length of texts (pandas): {mean_length_pd:.2f}")

In [None]:
class TextDataset(Dataset):
    """Dataset that tokenizes one text per item, without padding.

    Items have variable length; padding to a common length is left to the
    DataLoader's collate function.

    Args:
        texts: Sequence of raw strings.
        labels: Sequence of integer class indices, same length as ``texts``.
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=True, max_length=..., return_tensors='pt')``
            that returns a mapping with ``'input_ids'`` and
            ``'attention_mask'`` tensors whose first dimension is 1.
        max_length: Optional truncation length forwarded to the tokenizer.
            ``None`` (default) leaves the limit to the tokenizer itself.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_length=None):
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # No padding here: sequences are padded per batch by the collate
        # function. Truncation stays on so over-long texts are cut.
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            # squeeze(0) drops the leading dimension of size 1 added by
            # return_tensors='pt' for a single text.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(self.labels[idx], dtype=torch.long)
        }

# Custom collate function to pad sequences within a batch
def collate_fn(batch, pad_token_id=None):
    """Pad a list of dataset items to the longest sequence in the batch.

    Args:
        batch: List of dicts with 1-D ``'input_ids'`` and
            ``'attention_mask'`` tensors and a scalar ``'label'`` tensor.
        pad_token_id: Value used to pad ``input_ids``. When ``None``
            (default, and what a DataLoader calling ``collate_fn(batch)``
            gets) the notebook-level ``tokenizer.pad_token_id`` is used.

    Returns:
        Dict with ``'input_ids'`` and ``'attention_mask'`` of shape
        ``[B, longest_sequence]`` and ``'label'`` of shape ``[B]``.
    """
    if pad_token_id is None:
        # Resolved at call time from the module-level `tokenizer`; pass
        # pad_token_id explicitly to avoid depending on that global.
        pad_token_id = tokenizer.pad_token_id

    input_ids = [item['input_ids'] for item in batch]
    attention_mask = [item['attention_mask'] for item in batch]
    labels = torch.stack([item['label'] for item in batch])

    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=pad_token_id
    )
    # Padded positions get 0 in the attention mask.
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        attention_mask, batch_first=True, padding_value=0
    )

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'label': labels
    }


tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Full (unsplit) dataset, used here only to sanity-check tokenization and batching.
dataset = TextDataset(texts, labels, tokenizer)

# collate_fn pads every batch to its own longest sequence.
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

# Individual items are unpadded, so their lengths differ.
print("Checking shapes of individual items from the dataset:")
for i in range(min(5, len(dataset))): # Check first 5 items or fewer if dataset is small
    item = dataset[i]
    print(f"Item {i}:")
    print("  Input IDs shape:", item['input_ids'].shape)
    print("  Attention Mask shape:", item['attention_mask'].shape)
    print("  Label shape:", item['label'].shape)

print("\nChecking shapes of a batch from the dataloader:")
for batch in dataloader:
    print("Input IDs shape:", batch['input_ids'].shape)
    print("Attention Mask shape:", batch['attention_mask'].shape)
    # Previously printed the label tensor itself under the "shape" heading.
    print("Label shape:", batch['label'].shape)
    break # Just check one batch

In [None]:
from sklearn.model_selection import train_test_split

# Split into training and a temporary set (validation + test).
# test_size=0.10 holds out 10% of the samples, leaving 90% for training.
# NOTE(review): earlier comments here described an 80/10/10 split; the code
# produces 90/5/5 — confirm which ratio is intended.
train_texts, temp_texts, train_labels, temp_labels = train_test_split(
    texts,
    labels,
    test_size=0.10,  # 10% for validation + test
    random_state=42,
    stratify=labels  # Keep the class proportions of `labels` in both parts
)

# Split the temporary set into validation and test sets.
# test_size=0.5 halves the 10% temp set into 5% validation and 5% test.
val_texts, test_texts, val_labels, test_labels = train_test_split(
    temp_texts,
    temp_labels,
    test_size=0.5,  # Half of 10% = 5% test
    random_state=42,
    stratify=temp_labels  # Keep class proportions in val/test
)

print(f"Training samples: {len(train_texts)}")
print(f"Validation samples: {len(val_texts)}")
print(f"Test samples: {len(test_texts)}")

In [None]:
# One TextDataset per split, all sharing the same tokenizer.
train_dataset, val_dataset, test_dataset = (
    TextDataset(split_texts, split_labels, tokenizer)
    for split_texts, split_labels in (
        (train_texts, train_labels),
        (val_texts, val_labels),
        (test_texts, test_labels),
    )
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

In [None]:
from torch.utils.data import DataLoader

# Settings shared by all three loaders: batch size and per-batch padding.
_loader_kwargs = dict(batch_size=8, collate_fn=collate_fn)

# Only the training loader shuffles; validation and test keep dataset order.
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

print(f"Number of batches in training dataloader: {len(train_dataloader)}")
print(f"Number of batches in validation dataloader: {len(val_dataloader)}")
print(f"Number of batches in test dataloader: {len(test_dataloader)}")

In [None]:
class PrototypicalNet(nn.Module):
    """BERT encoder with a projection head, classified against learnable prototypes.

    Each class owns ``num_prototypes`` vectors of size ``embed_dim``. A text is
    encoded, mean-pooled over its non-padding tokens, projected to
    ``embed_dim`` and scored against every prototype; per-class scores are the
    mean over that class's prototypes.

    Args:
        num_classes: Number of target classes.
        embed_dim: Size of the projected embedding and of each prototype.
        num_prototypes: Prototypes per class.
        similarity: ``'cosine'`` (scaled cosine similarity) or ``'euclidean'``
            (negative squared distance).

    Raises:
        ValueError: If ``similarity`` is not one of the supported modes.
    """

    SUPPORTED_SIMILARITIES = ('cosine', 'euclidean')

    def __init__(self, num_classes=4, embed_dim=256, num_prototypes=3, similarity='cosine'):
        super(PrototypicalNet, self).__init__()
        # Reject unsupported modes up front; otherwise forward() would fail
        # later with an unbound `logits`.
        if similarity not in self.SUPPORTED_SIMILARITIES:
            raise ValueError(
                f"similarity must be one of {self.SUPPORTED_SIMILARITIES}, got {similarity!r}"
            )

        self.encoder = BertModel.from_pretrained('bert-base-uncased')
        self.encoder_dim = self.encoder.config.hidden_size  # encoder hidden size
        self.num_classes = num_classes
        self.num_prototypes = num_prototypes
        self.embed_dim = embed_dim  # Final embedding dimension
        self.similarity = similarity

        # MLP to reduce the encoder output to embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(self.encoder_dim, embed_dim),
            nn.GELU(),
            nn.LayerNorm(embed_dim)
        )

        # Sized by embed_dim (was hard-coded to 256, which broke any other
        # embed_dim with a shape mismatch in forward()).
        self.prototypes = nn.Parameter(torch.randn(num_classes, num_prototypes, embed_dim))
        if similarity == 'cosine':
            self.s = nn.Parameter(torch.tensor(10.0))    # similarity scale
            self.b = nn.Parameter(torch.tensor(0.0))     # similarity bias
            self.temp = nn.Parameter(torch.tensor(1.0))  # Temperature scaling

    def embed(self, input_ids, attention_mask):
        """Return projected text embeddings of shape ``[B, embed_dim]``.

        Token states are averaged over positions where ``attention_mask`` is
        non-zero, so the result does not depend on how much padding the batch
        happened to contain.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state                      # [B, T, encoder_dim]
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)    # [B, T, 1]
        # clamp guards against division by zero for an all-padding row.
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.mlp(pooled)                                 # [B, embed_dim]

    def forward(self, input_ids, attention_mask):
        """Return ``(logits, embeddings)`` with shapes ``[B, C]`` and ``[B, embed_dim]``."""
        x = self.embed(input_ids, attention_mask)  # [B, embed_dim]

        if self.similarity == 'cosine':
            x_norm = F.normalize(x, p=2, dim=-1)  # [B, embed_dim]
            p_norm = F.normalize(self.prototypes, p=2, dim=-1)  # [C, K, embed_dim]
            x_exp = x_norm.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, embed_dim]
            p_exp = p_norm.unsqueeze(0)               # [1, C, K, embed_dim]
            sim = (x_exp * p_exp).sum(dim=-1)         # [B, C, K]
            sim = sim.mean(dim=2)                     # [B, C]
            logits = (self.s * sim + self.b) / self.temp  # [B, C]

        elif self.similarity == 'euclidean':
            dist = ((x.unsqueeze(1).unsqueeze(2) - self.prototypes.unsqueeze(0)) ** 2).sum(-1)  # [B, C, K]
            logits = -dist.mean(dim=2)  # [B, C]

        else:
            # Only reachable if `similarity` was reassigned after construction.
            raise ValueError(
                f"similarity must be one of {self.SUPPORTED_SIMILARITIES}, got {self.similarity!r}"
            )

        return logits, x

    def init_prototypes(self, dataloader, labels=None):
        """Initialize prototypes using KMeans clustering per class.

        Args:
            dataloader: Yields dicts with ``'input_ids'`` and
                ``'attention_mask'`` (and ``'label'`` when ``labels`` is None).
            labels: Optional sequence of class indices, one per sample, in the
                exact order the dataloader yields them (so the loader must not
                shuffle). When ``None``, labels are read from each batch's
                ``'label'`` entry, which is always correctly aligned.

        Raises:
            ValueError: If the number of labels differs from the number of
                embedded samples.
        """
        from sklearn.cluster import KMeans

        device = next(self.parameters()).device
        embeddings = []
        batch_labels = []

        # Embed in eval mode so dropout does not add noise to the cluster
        # inputs; the previous train/eval mode is restored afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for batch in dataloader:
                    x = self.embed(
                        batch['input_ids'].to(device),
                        batch['attention_mask'].to(device),
                    )
                    embeddings.append(x.cpu())
                    if labels is None:
                        batch_labels.append(batch['label'].cpu())
        finally:
            self.train(was_training)

        embeddings = torch.cat(embeddings)
        labels_tensor = torch.cat(batch_labels) if labels is None else torch.as_tensor(labels)
        if len(labels_tensor) != len(embeddings):
            raise ValueError(
                f"got {len(labels_tensor)} labels for {len(embeddings)} embedded samples"
            )

        for c in range(self.num_classes):
            class_embeddings = embeddings[labels_tensor == c]
            if len(class_embeddings) == 0:
                print(f"Warning: No samples found for class {c}. Initializing prototypes randomly.")
                self.prototypes.data[c] = torch.randn(self.num_prototypes, self.embed_dim).to(self.prototypes.device)
                continue

            # KMeans cannot produce more clusters than there are samples.
            n_clusters = min(self.num_prototypes, len(class_embeddings))
            kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
            kmeans.fit(class_embeddings.numpy())

            prototypes_c = torch.zeros(self.num_prototypes, self.embed_dim)
            prototypes_c[:n_clusters] = torch.tensor(kmeans.cluster_centers_, dtype=prototypes_c.dtype)
            # Fill any remaining slots with the class mean rather than leaving
            # all-zero prototypes.
            prototypes_c[n_clusters:] = class_embeddings.mean(dim=0)

            self.prototypes.data[c] = prototypes_c.to(self.prototypes.device)

In [None]:
def initialize_prototypes(model, dataloader, num_classes=4, num_prototypes=3):
    """Initialise ``model.prototypes`` from embeddings of labelled samples.

    Every sample in ``dataloader`` is embedded (encoder -> mean over
    non-padding tokens -> ``model.mlp``). For each class, ``num_prototypes``
    embeddings are drawn at random; a class with fewer samples uses all of
    them and is padded with copies of the class mean; a class with no samples
    gets random prototypes and a warning.

    The model is switched to eval mode and left in it.

    Args:
        model: Module exposing ``encoder``, ``mlp`` and a ``prototypes``
            parameter whose last dimension is the embedding size.
        dataloader: Yields dicts with ``'input_ids'``, ``'attention_mask'``
            and ``'label'`` tensors.
        num_classes: Number of classes; labels must lie in ``[0, num_classes)``.
        num_prototypes: Prototypes per class.
    """
    device = next(model.parameters()).device
    class_embeddings = [[] for _ in range(num_classes)]

    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Initializing Prototypes"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels_batch = batch['label'].to(device)

            outputs = model.encoder(input_ids=input_ids, attention_mask=attention_mask)
            hidden = outputs.last_hidden_state                    # [B, T, H]
            # Average only over real tokens so the embedding does not depend
            # on how much padding the batch contains.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # [B, T, 1]
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
            embeddings = model.mlp(pooled)                        # [B, embed_dim]

            for emb, label in zip(embeddings, labels_batch):
                class_embeddings[label.item()].append(emb)

    embed_dim = model.prototypes.shape[-1]
    prototype_list = []
    for class_index, class_embeds in enumerate(class_embeddings):
        if not class_embeds:
            # torch.stack([]) would raise; fall back to random prototypes.
            print(f"Warning: No samples found for class {class_index}. Initializing prototypes randomly.")
            prototype_list.append(torch.randn(num_prototypes, embed_dim, device=device))
            continue

        class_embeds = torch.stack(class_embeds)
        if len(class_embeds) < num_prototypes:
            # Too few samples: keep them all and pad with the class mean.
            mean_embed = class_embeds.mean(dim=0, keepdim=True)
            padded = mean_embed.repeat(num_prototypes - len(class_embeds), 1)
            proto_class = torch.cat([class_embeds, padded], dim=0)
        else:
            # Random subset; seed torch beforehand for a reproducible choice.
            indices = torch.randperm(len(class_embeds))[:num_prototypes]
            proto_class = class_embeds[indices]
        prototype_list.append(proto_class)

    model.prototypes.data = torch.stack(prototype_list).to(device)
    print(f"Prototypes shape after init: {model.prototypes.shape}")  # Should be [C, K, embed_dim]

In [None]:
def compute_loss_weights(labels, num_classes, alpha=0.5, device='cpu'):
    """Compute per-class loss weights from class frequencies.

    With ``n_c`` the number of samples of class ``c``, the weight is
    ``(n_c**alpha / sum_k n_k**alpha) / n_c``. Classes with no samples get
    weight 0 instead of the NaN/inf that a division by zero would produce.

    Args:
        labels: Sequence (or tensor) of integer class indices.
        num_classes: Number of classes; the result has this many entries.
        alpha: Exponent applied to the class counts.
        device: Device the returned tensor is moved to.

    Returns:
        Float tensor of shape ``[num_classes]`` on ``device``.

    Raises:
        IndexError: If a label lies outside ``[0, num_classes)``.
    """
    labels_tensor = torch.as_tensor(labels, dtype=torch.long).reshape(-1).cpu()
    if labels_tensor.numel() > 0 and (
        labels_tensor.min().item() < 0 or labels_tensor.max().item() >= num_classes
    ):
        raise IndexError(f"labels must lie in [0, {num_classes})")

    # Count samples per class in one vectorised call.
    label_count = torch.bincount(labels_tensor, minlength=num_classes).float()
    label_count_pow = label_count.pow(alpha)
    total = label_count_pow.sum()

    lw_weights = torch.zeros(num_classes)
    present = label_count > 0
    if total > 0:
        # Only classes that actually occur are re-weighted; absent classes
        # keep weight 0.
        lw_weights[present] = (label_count_pow[present] / total) / label_count[present]

    return lw_weights.to(device)

In [None]:
# Seed torch before anything stochastic (projection-head and prototype
# initialisation, random prototype sampling, DataLoader shuffling) so runs
# are repeatable.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PrototypicalNet(num_classes=4, embed_dim=256).to(device)

# Initialize prototypes from training-set embeddings only, so validation and
# test data do not leak into the model.
initialize_prototypes(model, train_dataloader, num_classes=4)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Class weights derived from the training label frequencies.
lw_weights = compute_loss_weights(train_labels, num_classes=4, alpha=0.5, device=device)

criterion = nn.CrossEntropyLoss(weight=lw_weights)

train_losses = [] # List to store training loss per epoch


In [None]:
from sklearn.metrics import accuracy_score

NUM_EPOCHS = 15


def run_epoch(model, dataloader, criterion, device, optimizer=None, desc=""):
    """Run one pass over ``dataloader`` and return ``(avg_loss, accuracy_percent)``.

    When ``optimizer`` is given the model is put in train mode and updated
    after every batch; otherwise it is put in eval mode and gradients are
    disabled. The loss is the mean of the per-batch losses.
    """
    is_training = optimizer is not None
    model.train(is_training)

    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.set_grad_enabled(is_training):
        for batch in tqdm(dataloader, desc=desc):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels_batch = batch['label'].to(device)

            if is_training:
                optimizer.zero_grad()

            # logits and labels come from the same batch, so their first
            # dimensions always match (the old `[:batch_size]` slice was a no-op).
            logits, _ = model(input_ids, attention_mask)
            loss = criterion(logits, labels_batch)

            if is_training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()

            # Store predictions and labels for accuracy calculation
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels_batch.cpu().numpy())

    avg_loss = total_loss / len(dataloader)
    accuracy = accuracy_score(all_labels, all_preds) * 100
    return avg_loss, accuracy


# Per-epoch metric history, consumed by the plotting cell.
train_losses = []
val_losses = []
val_accuracies = []

for epoch in range(NUM_EPOCHS):
    # --- Training Phase ---
    avg_train_loss, train_accuracy = run_epoch(
        model, train_dataloader, criterion, device,
        optimizer=optimizer, desc=f"Epoch {epoch+1} Training",
    )
    train_losses.append(avg_train_loss)

    # --- Validation Phase ---
    avg_val_loss, val_accuracy = run_epoch(
        model, val_dataloader, criterion, device,
        desc=f"Epoch {epoch+1} Validation",
    )
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_accuracy)

    # Print epoch summary
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Training Loss: {avg_train_loss:.4f} | Training Acc: {train_accuracy:.2f}%")
    print(f"  Validation Loss: {avg_val_loss:.4f} | Validation Acc: {val_accuracy:.2f}%\n")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Best validation epoch, reported 1-based; skipped when no epochs were recorded.
if not val_accuracies:
    print("\nNo validation accuracy data available to display best accuracy.")
else:
    print(f"\nBest Val Accuracy: {max(val_accuracies):.2f}% (Epoch {np.argmax(val_accuracies)+1})")

# Two panels side by side: loss curves on the left, validation accuracy on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(train_losses, label='Train Loss')
ax1.plot(val_losses, label='Val Loss')
ax1.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')
ax1.legend()

ax2.plot(val_accuracies, label='Val Acc', color='green')
ax2.set(title='Validation Accuracy', xlabel='Epoch', ylabel='Accuracy (%)')
ax2.legend()

fig.tight_layout()  # keep titles and labels from overlapping
plt.show()

# Evaluation Function
def evaluate(model, loader, device=None):
    """Print and return the classification accuracy (in percent) on ``loader``.

    Args:
        model: Module whose forward returns ``(logits, embeddings)``.
        loader: Yields dicts with ``'input_ids'``, ``'attention_mask'`` and
            ``'label'`` tensors.
        device: Device the batches are moved to. Defaults to the device of
            the model's first parameter (CPU if it has none).

    Returns:
        Accuracy as a float percentage, or ``nan`` when ``loader`` yields no
        samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = total = 0
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels_batch = batch['label'].to(device)

            logits, _ = model(input_ids, attention_mask)
            preds = torch.argmax(logits, 1)

            correct += (preds == labels_batch).sum().item()
            total += labels_batch.size(0)

    if total == 0:
        # Avoid a ZeroDivisionError on an empty loader.
        print("Accuracy: n/a (no samples evaluated)")
        return float('nan')

    accuracy = correct / total * 100
    print(f"Accuracy: {accuracy:.2f}%")
    return accuracy

# Final held-out evaluation; guarded so the cell still runs when the
# data-splitting cells were not executed.
print("\n--- Test Results ---")
if 'test_dataloader' not in locals():
    print("Test dataloader not found. Please perform data splitting and create test_dataloader first.")
else:
    evaluate(model, test_dataloader)