In [None]:
import torch
import torchaudio
import pyarrow.parquet as pq
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms
import numpy as np
from transformers import ViTModel
import pandas as pd
import torchvision.models as models
import os
from tqdm import tqdm
from torch.nn.functional import one_hot, softmax
import matplotlib
import matplotlib.cm as cm
from PIL import Image

In [None]:
# Competition metadata tables.
train = pd.read_csv('/kaggle/input/hms-harmful-brain-activity-classification/train.csv')
test = pd.read_csv('/kaggle/input/hms-harmful-brain-activity-classification/test.csv')

# Class vocabulary, in order of first appearance in `expert_consensus`.
# The index assigned here becomes the model's output index, so this ordering
# must match whatever notebook later decodes the predictions.
classes = train['expert_consensus'].unique()
mapping = dict(zip(classes, range(len(classes))))
num_classes = len(classes)

In [None]:
# Colormap used to turn scalar spectrogram values (expected in [0, 1]) into RGBA.
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# the `matplotlib.colormaps` registry (Matplotlib >= 3.5) replaces it.
try:
    cmap = matplotlib.colormaps["viridis"]
except AttributeError:  # very old Matplotlib without the colormap registry
    cmap = cm.get_cmap("viridis")

In [None]:
class SpectrogramDataset(Dataset):
    """Spectrogram parquet files rendered as RGB images, paired with a class index.

    Every file in ``data_folder`` is one sample. Its label is parsed from the
    file name as the text after the first ``_`` up to the next ``_`` or ``.``
    (e.g. ``"<id>_Seizure.parquet"`` -> ``"Seizure"``), then mapped to an
    integer through ``label_mapping``.

    Args:
        data_folder: directory holding one parquet file per sample. Each file
            must contain a ``time`` column, which is dropped before rendering.
        transform: optional callable applied to the RGB ``PIL.Image``
            (e.g. a torchvision ``Compose``).
        label_mapping: dict from label string to class index. ``None`` (the
            default) uses the notebook-level ``mapping``.
        colormap: callable turning an array of values in [0, 1] into RGBA
            floats (e.g. a matplotlib colormap). ``None`` (the default) uses
            the notebook-level ``cmap``.
    """

    def __init__(self, data_folder, transform=None, label_mapping=None, colormap=None):
        self.data_folder = data_folder
        # Sorted so that index -> file (and therefore any random_split built on
        # top of it) does not depend on the arbitrary order of os.listdir.
        self.file_paths = sorted(
            os.path.join(data_folder, f) for f in os.listdir(data_folder)
        )
        self.transform = transform
        self.label_mapping = label_mapping
        self.colormap = colormap

    def __len__(self):
        return len(self.file_paths)

    @staticmethod
    def _to_unit_range(values):
        """Min-max scale ``values`` into [0, 1]; NaN/inf entries map to 0.

        A constant (or entirely non-finite) array maps to all zeros.
        NOTE(review): raw spectrogram power is usually heavy-tailed; a log
        transform before scaling may give better contrast — TODO evaluate.
        """
        values = np.asarray(values, dtype=np.float64)
        finite = np.isfinite(values)
        scaled = np.zeros_like(values)
        if not finite.any():
            return scaled
        lo = values[finite].min()
        hi = values[finite].max()
        if hi > lo:
            scaled[finite] = (values[finite] - lo) / (hi - lo)
        return scaled

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for sample ``idx``.

        ``image`` is ``transform(rgb_image)`` when a transform is set,
        otherwise the 3-channel RGB ``PIL.Image`` itself.
        """
        file_path = self.file_paths[idx]
        label = os.path.basename(file_path).split("_")[1].split(".")[0]  # Extract label from filename
        label_mapping = mapping if self.label_mapping is None else self.label_mapping
        colormap = cmap if self.colormap is None else self.colormap

        values = pd.read_parquet(file_path).drop('time', axis=1).values
        # Colormaps expect floats in [0, 1] and clip anything above 1 to the top
        # colour, so unscaled values would render as an almost uniform image.
        rgba = colormap(self._to_unit_range(values))
        # Keep RGB only: alpha carries no information and the model takes 3 channels.
        rgb = (rgba[..., :3] * 255).astype(np.uint8)
        spectrogram = Image.fromarray(rgb)
        if self.transform:
            spectrogram = self.transform(spectrogram)

        return spectrogram, label_mapping[label]

In [None]:
data_folder = "/kaggle/input/hms-data-prepare-separate-spectogram/separate_spectogram/"

IMAGE_SIZE = 224        # input resolution of ViT-B/16-224
# google/vit-base-patch16-224-in21k was pretrained on inputs normalised with
# mean = std = 0.5 per channel (its image processor maps [0, 1] -> [-1, 1]).
VIT_MEAN = [0.5, 0.5, 0.5]
VIT_STD = [0.5, 0.5, 0.5]
TRAIN_FRACTION = 0.8
BATCH_SIZE = 32
SPLIT_SEED = 42

transform = transforms.Compose([
    # Drop any alpha channel so Normalize sees exactly 3 channels.
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),  # PIL uint8 (H, W, C) -> float tensor (C, H, W) in [0, 1]
    transforms.Normalize(mean=VIT_MEAN, std=VIT_STD),
])

dataset = SpectrogramDataset(data_folder, transform=transform)
train_length = int(TRAIN_FRACTION * len(dataset))
val_length = len(dataset) - train_length

print(f'train size: {train_length} \nval size: {val_length}')

# Seeded generator: the train/val partition is identical on every run
# (given the same dataset ordering), so validation scores are comparable.
train_set, val_set = torch.utils.data.random_split(
    dataset,
    [train_length, val_length],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)

In [None]:
class ViTClassifier(torch.nn.Module):
    """Pretrained ViT-B/16 backbone with a linear classification head.

    ``forward`` returns class probabilities (softmax over ``num_classes``).
    ``logits`` returns the pre-softmax scores. Prefer it when a loss needs
    log-probabilities: ``log_softmax(logits)`` is numerically stable, whereas
    ``log(softmax(...))`` becomes ``-inf`` once a probability underflows to 0.

    Args:
        num_classes: number of output classes.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.classifier = torch.nn.Linear(self.vit.config.hidden_size, num_classes)

    def logits(self, images):
        """Return unnormalised class scores of shape (batch, num_classes).

        ``images`` is passed straight to the ViT backbone — presumably a
        (batch, 3, 224, 224) float tensor; TODO confirm against the transform.
        """
        hidden = self.vit(images).last_hidden_state
        # Position 0 of the sequence is ViT's [CLS] token; its final hidden
        # state serves as the whole-image embedding.
        return self.classifier(hidden[:, 0])

    def forward(self, images):
        """Return class probabilities for ``images``; each row sums to 1."""
        return softmax(self.logits(images), dim=1)

In [None]:
# Use the GPU when present; fall back to CPU so the notebook still runs
# (slowly) on a kernel without CUDA instead of crashing in `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ViTClassifier(num_classes).to(device)

In [None]:
# KLDivLoss takes log-probabilities as input and probabilities as target.
# With one-hot targets and reduction="batchmean" it equals the mean
# per-sample cross-entropy.
loss_fn = nn.KLDivLoss(reduction="batchmean")

# Full fine-tuning of a pretrained ViT: Adam at 1e-3 typically wrecks the
# pretrained weights or diverges; ViT fine-tuning recipes use ~1e-5 to 1e-4.
LEARNING_RATE = 5e-5
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
num_epochs = 10
# Lower bound applied to the model's probabilities before log(). Softmax output
# can underflow to exactly 0; log(0) = -inf, and KLDivLoss then evaluates
# 0 * -inf = NaN for the zero-target classes, poisoning loss and gradients.
prob_floor = 1e-8

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()  # train/eval toggle train-only behaviour (e.g. dropout) in submodules
    pbar = tqdm(train_loader)
    for spectrograms, labels in pbar:
        labels_onehot = one_hot(labels, num_classes=num_classes).float().to(device)
        spectrograms = spectrograms.to(device)
        optimizer.zero_grad()
        outputs = model(spectrograms)
        loss = loss_fn(outputs.clamp(min=prob_floor).log(), labels_onehot)
        loss.backward()
        optimizer.step()
        pbar.set_description(f'epoch {epoch + 1}/{num_epochs} train loss {loss.item():.4f}')

    # ---- validation pass ----
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0  # sum of per-sample KL over the validation set
    with torch.no_grad():
        pbar = tqdm(val_loader)
        for spectrograms, labels in pbar:
            labels = labels.to(device)
            labels_onehot = one_hot(labels, num_classes=num_classes).float()
            spectrograms = spectrograms.to(device)
            outputs = model(spectrograms)
            predicted = outputs.argmax(dim=1)
            batch_size = labels.size(0)
            total += batch_size
            correct += (predicted == labels).sum().item()
            # "batchmean" yields this batch's per-sample mean; weight it by the
            # batch size so a short final batch does not skew the epoch average.
            loss = loss_fn(outputs.clamp(min=prob_floor).log(), labels_onehot).item()
            val_loss += loss * batch_size
            pbar.set_description(f'epoch {epoch + 1}/{num_epochs} val loss {loss:.4f}')

    if total == 0:
        print("Validation set is empty; skipping validation metrics.")
    else:
        accuracy = 100 * correct / total
        kl_divergence = val_loss / total

        print("Validation Accuracy:", accuracy)
        print("Validation KL Divergence:", kl_divergence)


# Save the trained model
torch.save(model.state_dict(), "trained_model.pt")