# Acknowledgment
This notebook uses the preprocessed dataset produced by [MUHAMMAD AHMED](https://www.kaggle.com/muhammad4hmed) in their notebook titled **[HMS] - Data Prepare - Separate Spectogram**. I would like to express my gratitude for their valuable contribution to the Kaggle community.

Link to the original notebook: [Original Notebook Link](https://www.kaggle.com/code/muhammad4hmed/hms-data-prepare-separate-spectogram)

# Experiment Details
* Model: ViT, Epoch: 1, BS: 32 , CV: , LB: 0.96
* Model: ViT, Epoch: 15, BS: 32 , CV: , LB: 

[Inference Notebook](https://www.kaggle.com/dky7376/gpu-infer-hms-vit-pipeline)

# Load the dataset

In [None]:
import pandas as pd

# Read the competition metadata tables.
train = pd.read_csv('/kaggle/input/hms-harmful-brain-activity-classification/train.csv')
test = pd.read_csv('/kaggle/input/hms-harmful-brain-activity-classification/test.csv')

# Distinct consensus labels, each assigned an integer index by position.
classes = train['expert_consensus'].unique()
num_classes = classes.shape[0]
mapping = dict(zip(classes, range(num_classes)))

print(mapping)
print(num_classes)

In [None]:
# Preview the first rows of the training metadata.
display(train.head())

In [None]:
# Preview the first rows of the test metadata.
display(test.head())

# Train the Model

In [None]:
%%writefile train.py

import os
import argparse
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
from transformers import ViTModel
import matplotlib as cm
from torch.nn.functional import softmax, one_hot
# Viridis colormap; the dataset applies it to spectrogram values to build
# the image pixels. Note that `cm` here is the top-level matplotlib package.
cmap = cm.colormaps["viridis"]

# Define constants
# Width of the one-hot targets built in train()/validate().
num_classes = 6
# Number of passes over the training split.
num_epochs = 15
# Samples per batch for both data loaders.
batch_size = 32

# Define dataset class
class SpectrogramDataset(Dataset):
    """Dataset of spectrogram parquet files whose label is encoded in the filename.

    Args:
        file_paths: Indexable collection of parquet file paths.
        mapping: Dict translating a label string to its integer class index.
        transform: Optional callable applied to the PIL image of each sample.
    """

    def __init__(self, file_paths, mapping, transform=None):
        self.file_paths = file_paths
        self.mapping = mapping
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for the file at position ``idx``."""
        path = self.file_paths[idx]

        # The label is the second "_"-separated field of the file name,
        # truncated at its first dot.
        name_fields = os.path.basename(path).split("_")
        label = name_fields[1].split(".")[0]

        # Every column except 'time' is treated as spectrogram data.
        frame = pd.read_parquet(path)
        values = frame.drop('time', axis=1).values

        # Colour-map the values and scale to 8-bit pixels.
        # NOTE(review): colormaps expect float input in [0, 1] — confirm the
        # parquet values are already normalised.
        pixels = (cmap(values) * 255).astype(np.uint8)
        image = Image.fromarray(pixels)

        if self.transform:
            # Keep only the first three channels of the transformed output
            # (presumably dropping the colormap's alpha channel).
            image = self.transform(image)[:3, :, :]
        return image, self.mapping[label]


# Define transforms
# Preprocessing applied to every spectrogram image before it reaches the model.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.ToTensor(),  # Convert to PyTorch tensor
])

# Define model class
class ViTClassifier(nn.Module):
    """Pretrained ViT backbone followed by a linear head and a softmax.

    Args:
        num_classes: Number of output classes produced by the linear head.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        backbone = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.vit = backbone
        self.classifier = nn.Linear(backbone.config.hidden_size, num_classes)

    def forward(self, images):
        """Return per-class probabilities (softmax over dim 1) for ``images``."""
        hidden_states = self.vit(images).last_hidden_state
        # Classify from the hidden state at sequence position 0.
        first_token = hidden_states[:, 0]
        logits = self.classifier(first_token)
        return softmax(logits, dim=1)

# Training function
def train(model, train_loader, optimizer, scheduler, loss_fn, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module returning class probabilities for a batch of inputs.
        train_loader: Iterable yielding ``(spectrograms, labels)`` batches,
            where ``labels`` holds integer class indices.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: LR scheduler; it is stepped once per batch here.
        loss_fn: Loss called as ``loss_fn(log_probabilities, one_hot_targets)``.
        device: Device the batch tensors are moved to.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    model.train()
    total_loss = 0.0

    for spectrograms, labels in tqdm(train_loader, desc="Training"):
        labels_onehot = one_hot(labels, num_classes=num_classes).float().to(device)
        spectrograms = spectrograms.to(device)

        optimizer.zero_grad()
        outputs = model(spectrograms)

        # The model emits probabilities. A probability that underflows to 0
        # would give log(0) = -inf and turn the loss and gradients into NaN,
        # so clamp to the smallest positive normal value before the log.
        log_probs = outputs.clamp_min(torch.finfo(outputs.dtype).tiny).log()
        loss = loss_fn(log_probs, labels_onehot)

        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

# Validation function
def validate(model, val_loader, loss_fn, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Module returning class probabilities for a batch of inputs.
        val_loader: Iterable yielding ``(spectrograms, labels)`` batches,
            where ``labels`` holds integer class indices.
        loss_fn: Loss called as ``loss_fn(log_probabilities, one_hot_targets)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(accuracy, kl_divergence)``: accuracy as a percentage of
        correctly classified samples, and the mean per-batch loss.
    """
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0

    with torch.no_grad():
        for spectrograms, labels in tqdm(val_loader, desc="Validation"):
            labels = labels.to(device)
            labels_onehot = one_hot(labels, num_classes=num_classes).float()
            spectrograms = spectrograms.to(device)
            outputs = model(spectrograms)

            # Predicted class is the one with the highest probability.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Clamp before the log so a probability that underflowed to 0
            # cannot produce -inf and a NaN loss.
            log_probs = outputs.clamp_min(torch.finfo(outputs.dtype).tiny).log()
            val_loss += loss_fn(log_probs, labels_onehot).item()
    accuracy = 100 * correct / total
    kl_divergence = val_loss / len(val_loader)
    return accuracy, kl_divergence

# Define main function for training
def main():
    """Parse CLI arguments, train the ViT classifier and save its weights.

    Command-line arguments:
        --train_path: CSV file with an ``expert_consensus`` column; its
            distinct values define the label-to-index mapping.
        --data_folder: Directory whose files are the spectrogram samples.

    Raises:
        ValueError: If the number of distinct labels in the CSV differs from
            the module-level ``num_classes`` used to build one-hot targets.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_path", type=str)
    parser.add_argument("--data_folder", type=str)

    args = parser.parse_args()

    # Load the dataset
    train_data = pd.read_csv(args.train_path)
    classes = train_data['expert_consensus'].unique()
    mapping = {
        c:i for i, c in enumerate(classes)
    }

    # train()/validate() one-hot encode with the module-level num_classes, so
    # the model head must have exactly that width. Fail early with a clear
    # message instead of a shape error deep inside the loss.
    detected_num_classes = classes.shape[0]
    if detected_num_classes != num_classes:
        raise ValueError(
            f"Found {detected_num_classes} classes in {args.train_path}, "
            f"but num_classes is set to {num_classes}"
        )

    # Check if GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load dataset
    file_paths = [os.path.join(args.data_folder, f) for f in os.listdir(args.data_folder)]
    dataset = SpectrogramDataset(file_paths, mapping, transform=transform)

    # Split dataset into training and validation sets
    train_length = int(0.8 * len(dataset))
    val_length = len(dataset) - train_length
    train_set, val_set = torch.utils.data.random_split(dataset, [train_length, val_length])

    # Create data loaders
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    # Instantiate model, loss function, and optimizer
    model = ViTClassifier(num_classes).to(device)

    loss_fn = nn.KLDivLoss(reduction="batchmean")
    optimizer = optim.Adam(model.parameters(), lr=5e-4)

    # Set up the learning rate scheduler.
    # train() steps the scheduler once per batch, so the cosine period must be
    # the total number of optimizer steps. With T_max=num_epochs the learning
    # rate would swing up and down every few batches instead of decaying once
    # over the whole run.
    total_steps = max(1, num_epochs * len(train_loader))
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=0.0001)

    # Training loop
    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, optimizer, scheduler, loss_fn, device)
        val_accuracy, val_kl_divergence = validate(model, val_loader, loss_fn, device)

        print(f"Epoch {epoch + 1}/{num_epochs} => "
              f"Train Loss: {train_loss:.4f}, "
              f"Validation Accuracy: {val_accuracy:.2f}%, "
              f"Validation KL Divergence: {val_kl_divergence:.4f}")

    # Save the trained model
    torch.save(model.state_dict(), "trained_hms_vit_model_v4.pt")

# Run training only when the file is executed as a script.
if __name__ == "__main__":
    main()

In [None]:
# Paths passed to train.py as --data_folder and --train_path.
DATA_FOLDER = "/kaggle/input/hms-data-prepare-separate-spectogram/separate_spectogram/"
TRAIN_DATASET = "/kaggle/input/hms-harmful-brain-activity-classification/train.csv"

In [None]:
!accelerate launch --num_processes 2  train.py \
  --train_path $TRAIN_DATASET \
  --data_folder $DATA_FOLDER