In [1]:
import numpy as np
from datasets import load_dataset
import json
import matplotlib.pyplot as plt
from transformers import ViTFeatureExtractor
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)
from torch.utils.data import DataLoader
import torch
from transformers import ViTForImageClassification
from transformers import file_utils
from transformers import TrainingArguments, Trainer
from datasets import load_metric
import time
from transformers import EarlyStoppingCallback, IntervalStrategy



In [14]:
# Evaluation set: noisy mel-spectrogram images. The whole directory is requested
# as the imagefolder "train" split (NOTE(review): assumes noisy_mel has class
# subfolders only, no train/test split folders — confirm).
val_ds = load_dataset(path="imagefolder", data_dir="UrbanSound8K/noisy_mel", split="train")

Resolving data files:   0%|          | 0/8735 [00:00<?, ?it/s]

Using custom data configuration default-a0b70fe5fa2aa308
Found cached dataset imagefolder (/home/suvivars/.cache/huggingface/datasets/imagefolder/default-a0b70fe5fa2aa308/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)


In [15]:
# Label maps saved next to the clean training spectrograms (audio_mel).
# JSON object keys are always strings, so id2label's keys are cast back to int.
with open('UrbanSound8K/audio_mel/label2id.json') as f:
    label2id = json.load(f)

with open('UrbanSound8K/audio_mel/id2label.json') as f:
    id2label = {int(key): value for key, value in json.load(f).items()}

# The eval images come from a different folder (noisy_mel), and imagefolder infers
# its own label ids from that folder's class-directory names. If that order
# disagreed with the maps passed to the model, accuracy would be computed against
# permuted labels with no error, so check both directions explicitly.
expected_label_names = [id2label[label_id] for label_id in sorted(id2label)]
dataset_label_names = val_ds.features["label"].names
assert dataset_label_names == expected_label_names, (
    f"label order mismatch: dataset {dataset_label_names} vs id2label {expected_label_names}"
)
assert label2id == {name: label_id for label_id, name in id2label.items()}, (
    "label2id.json is not the inverse of id2label.json"
)

id2label

{0: 'air_conditioner',
 1: 'car_horn',
 2: 'children_playing',
 3: 'dog_bark',
 4: 'drilling',
 5: 'engine_idling',
 6: 'gun_shot',
 7: 'jackhammer',
 8: 'siren',
 9: 'street_music'}

In [16]:
# Preprocessing config of the base ViT checkpoint. The transform cells read its
# `size`, `image_mean` and `image_std` attributes.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224-in21k"
)

In [17]:
# Per-channel normalisation using the mean/std published by the feature extractor.
channel_mean = feature_extractor.image_mean
channel_std = feature_extractor.image_std
normalize = Normalize(mean=channel_mean, std=channel_std)

In [18]:
# Deterministic eval-time preprocessing: resize to the checkpoint's input size,
# convert to a tensor, then normalise with the extractor's channel statistics.
# Height and width are read by key. The previous tuple(size.values()) only worked
# because the dict happened to be ordered (height, width), and it would silently
# reinterpret a size dict with different keys instead of failing.
_eval_size = (feature_extractor.size["height"], feature_extractor.size["width"])
_val_transforms = Compose(
    [
        Resize(_eval_size),
        # A no-op after an exact (h, w) resize. Kept so the pipeline mirrors the
        # usual resize-then-crop eval recipe.
        CenterCrop(_eval_size),
        ToTensor(),
        normalize,
    ]
)

In [19]:
def val_transforms(examples):
    """Batch transform for ``val_ds.set_transform``: attach ``pixel_values``.

    Every image in ``examples['image']`` is converted to RGB and run through
    ``_val_transforms``. The results are stored as a list under
    ``examples['pixel_values']``. The batch dict is mutated in place and returned.
    """
    batch_images = examples['image']
    examples['pixel_values'] = list(
        map(lambda img: _val_transforms(img.convert("RGB")), batch_images)
    )
    return examples

In [20]:
# Register val_transforms so it runs on the fly whenever val_ds examples are accessed.
val_ds.set_transform(transform=val_transforms)

In [26]:
metric_name = "accuracy"

args = TrainingArguments(
    output_dir="train-UrbanSounds8k",
    # Evaluate and checkpoint once per epoch, and restore the checkpoint with the
    # best `metric_name` when training ends.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    # Optimisation / batching.
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=50,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    logging_dir='logs',
    # Presumably so the raw `image` column survives until val_transforms reads it.
    remove_unused_columns=False,
)



In [24]:
def compute_metrics(eval_pred):
    """Top-1 accuracy for ``Trainer`` evaluation / prediction.

    Args:
        eval_pred: ``(predictions, labels)`` pair as handed over by ``Trainer``.
            ``predictions`` are per-class scores, and the argmax is taken over
            axis 1. If the model returned several outputs, ``predictions`` is a
            tuple and its first element (the logits) is used.

    Returns:
        ``{"accuracy": float}``: the fraction of examples whose argmax
        prediction equals the reference label. The key matches
        ``metric_name`` ("accuracy") used for ``metric_for_best_model``.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_ids = np.argmax(predictions, axis=1)
    # Plain numpy instead of datasets.load_metric gives the same value without
    # the deprecated API or a metric-script download at runtime.
    return {"accuracy": float(np.mean(predicted_ids == np.asarray(labels)))}

  """Entry point for launching an IPython kernel.


In [27]:
# Evaluation-only Trainer for the fine-tuned checkpoint.
#
# Previously this cell used `model` before the cell that loads the checkpoint
# had run (In[27] before In[28]). The Trainer therefore wrapped whatever `model`
# was left in the kernel, so the reported metrics were not guaranteed to come
# from this checkpoint. The checkpoint is now loaded here, before the Trainer
# is built. `train_ds` and `collate_fn` were also never defined (NameError on
# Restart & Run All). Nothing is trained here, so no train_dataset is passed,
# and the collator is defined below.


def collate_fn(examples):
    """Stack transformed examples into the batch dict the ViT model consumes.

    Only ``pixel_values`` (added by val_transforms) and the label are kept. The
    raw PIL ``image`` column, which stays in each example because
    ``remove_unused_columns=False``, is dropped. Assumes the imagefolder
    dataset exposes its class id under ``label`` — confirm.
    """
    return {
        "pixel_values": torch.stack([example["pixel_values"] for example in examples]),
        "labels": torch.tensor([example["label"] for example in examples]),
    }


model = ViTForImageClassification.from_pretrained('models/Noisy-UrbanSounds8k-EarlyStopping',
                                                  num_labels=len(id2label),
                                                  id2label=id2label,
                                                  label2id=label2id)

trainer = Trainer(
    model,
    args,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)


In [28]:
# Load the fine-tuned checkpoint with the label maps read from the JSON files.
# NOTE(review): `trainer` keeps the model object it was constructed with, so
# rebinding `model` here does not change what `trainer.predict` evaluates.
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path='models/Noisy-UrbanSounds8k-EarlyStopping',
    num_labels=10,
    id2label=id2label,
    label2id=label2id,
)

loading configuration file models/Noisy-UrbanSounds8k-EarlyStopping/config.json
Model config ViTConfig {
  "_name_or_path": "google/vit-base-patch16-224-in21k",
  "architectures": [
    "ViTForImageClassification"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "air_conditioner",
    "1": "car_horn",
    "2": "children_playing",
    "3": "dog_bark",
    "4": "drilling",
    "5": "engine_idling",
    "6": "gun_shot",
    "7": "jackhammer",
    "8": "siren",
    "9": "street_music"
  },
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "air_conditioner": 0,
    "car_horn": 1,
    "children_playing": 2,
    "dog_bark": 3,
    "drilling": 4,
    "engine_idling": 5,
    "gun_shot": 6,
    "jackhammer": 7,
    "siren": 8,
    "street_music": 9
  },
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12

In [29]:
# Inference over the whole noisy eval set. The result carries predictions,
# label ids and a metrics dict (shown in the next cell).
outputs = trainer.predict(test_dataset=val_ds)

***** Running Prediction *****
  Num examples = 8732
  Batch size = 4


In [37]:
# Bare last expression gives rich notebook display of the metrics dict. Its keys
# carry the "test_" prefix that trainer.predict adds (see the recorded output).
outputs.metrics

{'test_loss': 0.2682583540678024, 'test_accuracy': 0.6516720109940448, 'test_runtime': 128.9412, 'test_samples_per_second': 67.721, 'test_steps_per_second': 16.93}
