# Acknowledgment
This notebook reuses parts of the inference code written by [MUHAMMAD AHMED](https://www.kaggle.com/muhammad4hmed) in their notebook **[HMS] Inference - ViT on Spectrograms**. I would like to express my gratitude for their valuable contribution to the Kaggle community.

Link to the original notebook: [Original Notebook Link](https://www.kaggle.com/code/muhammad4hmed/hms-inference-vit-on-spectrograms)

# Experiment Details
* Model: ViT, Epochs: 1, BS: 32, CV: 1.607, LB: 0.96
* Model: ViT, Epochs: 15, BS: 32, CV: 1.3912, LB: not recorded

[Training Notebook](https://www.kaggle.com/dky7376/gpu-train-hms-vit-pipeline)

# Inference

In [None]:
%%writefile inference.py

import os
import argparse
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from transformers import ViTModel
import matplotlib as cm
from torch.nn.functional import softmax, one_hot
# Viridis colormap used by SpectrogramDataset to colour-map raw spectrogram values
cmap = cm.colormaps["viridis"]

# Define constants
num_classes = 6  # output size of the classifier head built in main()
batch_size = 32  # number of samples per DataLoader batch during inference

# Define dataset class
class SpectrogramDataset(Dataset):
    """Dataset that loads test spectrograms from parquet files as images.

    Each item is identified by the ``spectrogram_id`` column of ``data``, read
    from the competition's ``test_spectrograms`` directory, colour-mapped with
    the module-level ``cmap`` and converted to a PIL image.

    Args:
        data: DataFrame with a ``spectrogram_id`` column, one row per sample.
        transform: Optional callable applied to the PIL image. Its result is
            sliced with ``[:3, :, :]``, so it must return a channel-first
            array or tensor.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the spectrogram for the row at position ``idx``.

        Returns the transformed image restricted to its first three channels
        when a transform is set, otherwise the untransformed PIL image.
        """
        # Positional lookup: samplers hand out positions in range(len(self)),
        # which only coincide with index labels for a default RangeIndex.
        specto_id = self.data['spectrogram_id'].iloc[idx]
        specto_path = f'/kaggle/input/hms-harmful-brain-activity-classification/test_spectrograms/{specto_id}.parquet'
        specto = pd.read_parquet(specto_path)
        # The colormap output is scaled by 255 and cast to uint8 to build the
        # image (presumably RGBA floats in [0, 1] - confirm against matplotlib).
        spectrogram = Image.fromarray((cmap(specto) * 255).astype(np.uint8))
        if self.transform:
            # Keep only the first three channels (drops the extra channel the
            # colour-mapped image carries, presumably alpha).
            spectrogram = self.transform(spectrogram)[:3, :, :]
        return spectrogram

# Define transforms
# Pipeline passed to SpectrogramDataset in main(); it is applied to each PIL
# spectrogram image before batching.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224 (presumably the input size the "patch16-224" checkpoint expects - confirm)
    transforms.ToTensor(),  # Convert to PyTorch tensor
])

# Define model class
class ViTClassifier(torch.nn.Module):
    """ViT backbone followed by a linear head that emits class probabilities.

    Args:
        num_classes: Width of the linear classification head.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        backbone = ViTModel.from_pretrained("/kaggle/input/google-vit-base-patch16-224-in21k")
        # Attribute names (and their registration order) are part of the
        # checkpoint layout loaded via load_state_dict, so they stay as-is.
        self.vit = backbone
        self.classifier = torch.nn.Linear(backbone.config.hidden_size, num_classes)

    def forward(self, images):
        """Return softmax probabilities over the classes for a batch of images."""
        hidden_states = self.vit(images).last_hidden_state
        # Classify from the first token of the sequence (presumably the
        # [CLS] position of the ViT encoder output).
        first_token = hidden_states[:, 0]
        logits = self.classifier(first_token)
        return softmax(logits, dim=1)

# Inference function
def inference(model, test_loader, device):
    """Run ``model`` over every batch of ``test_loader`` and stack the results.

    Args:
        model: Module to evaluate; it is switched to eval mode first.
        test_loader: Iterable yielding batches of input tensors.
        device: Device each batch is moved to before the forward pass.

    Returns:
        A NumPy array with the per-batch outputs stacked vertically, in the
        order the loader produced them.
    """
    model.eval()
    batch_outputs = []
    # Gradients are not needed for prediction.
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Inference"):
            predictions = model(batch.to(device))
            batch_outputs.append(predictions.detach().cpu().numpy())
    return np.vstack(batch_outputs)
    

# Define main function for inference
def main():
    """Run inference on the test set and write ``submission.csv``.

    Command-line arguments:
        --test_path: CSV file with ``eeg_id`` and ``spectrogram_id`` columns.
        --trained_model: File with the state dict for ``ViTClassifier``.

    Writes ``submission.csv`` to the current working directory, one row per
    row of the test CSV.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_path", type=str)
    parser.add_argument("--trained_model", type=str)
    args = parser.parse_args()

    # Load the dataset
    test_data = pd.read_csv(args.test_path)

    # Check if GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    dataset = SpectrogramDataset(test_data, transform=transform)

    # Predictions are written back next to test_data's rows by position, so
    # the loader must keep the dataset order (no shuffling).
    test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Load the trained model; map_location lets GPU-saved weights load on
    # whichever device was selected above (including CPU-only runs).
    model = ViTClassifier(num_classes).to(device)
    model.load_state_dict(torch.load(args.trained_model, map_location=device))

    outputs = inference(model, test_loader, device)

    # Copy so the new columns land on an independent frame rather than on a
    # selection of test_data.
    submission = test_data[['eeg_id']].copy()
    # Submission column -> model output index.
    # NOTE(review): this mapping must mirror the label encoding used at
    # training time, which is not visible here - confirm.
    vote_columns = {
        'seizure_vote': 0,
        'lpd_vote': 5,
        'gpd_vote': 1,
        'lrda_vote': 2,
        'grda_vote': 4,
        'other_vote': 3,
    }
    for column, class_idx in vote_columns.items():
        submission[column] = outputs[:, class_idx]
    submission.to_csv('submission.csv', index=False)


if __name__ == "__main__":
    main()


In [None]:
# Paths handed to inference.py on the command line:
# the test metadata CSV and the trained model's saved weights.
TEST_PATH = "/kaggle/input/hms-harmful-brain-activity-classification/test.csv"
TRAINED_MODEL = '/kaggle/input/gpu-train-hms-vit-pipeline/trained_hms_vit_model_v4.pt'

In [None]:
# Launch inference.py through accelerate with two processes (intended to utilize both T4 GPUs).
# NOTE(review): inference.py does not shard the test set by process, so each
# process presumably repeats the full inference and writes the same
# submission.csv - confirm whether --num_processes 2 is actually wanted.
!accelerate launch --num_processes 2  inference.py \
  --test_path $TEST_PATH \
  --trained_model $TRAINED_MODEL

In [None]:
import pandas as pd

# Read back the submission file written by inference.py and preview its first rows.
sub = pd.read_csv("/kaggle/working/submission.csv")
display(sub.head())