# <img src="https://img.icons8.com/bubbles/50/000000/mind-map.png" style="height:50px;display:inline"> ECE 046211 - Technion - Deep Learning
---

## Project
---

In [None]:
%pip install datasets
%pip install transformers
%pip install accelerate
%pip install evaluate
%pip install gc

In [1]:
from datasets import load_dataset, Audio,DatasetDict
import numpy as np
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, TrainingArguments, Trainer,AutoConfig
import evaluate
import torch
import gc

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
gtzan = load_dataset("marsyas/gtzan", "all",trust_remote_code=True)
train_rest = gtzan["train"].train_test_split(seed=42, shuffle=True, test_size=0.2)
val_test = train_rest["test"].train_test_split(seed=42, shuffle=True, test_size=0.5)
gtzan=DatasetDict({'train': train_rest["train"],'val': val_test["train"],'test': val_test["test"]})

In [3]:
id2label_fn = gtzan["train"].features["genre"].int2str

In [4]:
model_id = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
sampling_rate = feature_extractor.sampling_rate
gtzan = gtzan.cast_column("audio", Audio(sampling_rate=sampling_rate))

In [5]:
max_duration = 30.0


def preprocess_function(examples):
    audio_arrays = [x["array"] for x in examples["audio"]]
    inputs = feature_extractor(
        audio_arrays,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=int(feature_extractor.sampling_rate * max_duration),
        truncation=True,
    )
    return inputs

In [6]:
gtzan_encoded = gtzan.map(
    preprocess_function,
    remove_columns=["audio", "file"],
    batched=True,
    batch_size=100,
    num_proc=1,
)

In [7]:
gtzan_encoded = gtzan_encoded.rename_column("genre", "label")
id2label = {
    i: id2label_fn(i)
    for i in range(len(gtzan_encoded["train"].features["label"].names))
}
label2id = {v: k for k, v in id2label.items()}

In [8]:
metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Computes accuracy on a batch of predictions"""
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)

In [9]:
gc.collect()
torch.cuda.empty_cache()

num_labels = len(id2label)
model = AutoModelForAudioClassification.from_pretrained(model_id)
model.generation_config=AutoConfig.from_pretrained(model_id)
model.generation_config.id2label=id2label
model.generation_config.label2id=label2id
model.requires_grad_(False)
model.num_labels=num_labels
model.classifier.dense= torch.nn.Linear(model.classifier.dense.in_features,num_labels)

model_name = model_id.split("/")[-1]
batch_size = 64
gradient_accumulation_steps = 1
num_train_epochs = 50

training_args = TrainingArguments(
    f"{model_name}-finetuned-gtzan",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_train_epochs,
    warmup_ratio=0.1,
    logging_steps=5,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    fp16=True,
    push_to_hub=False,
    save_total_limit=5,
)
trainer = Trainer(
    model,
    training_args,
    train_dataset=gtzan_encoded["train"].with_format("torch"),
    eval_dataset=gtzan_encoded["val"].with_format("torch"),
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

trainer.train()

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
  context_layer = torch.nn.functional.scaled_dot_product_attention(
  1%|          | 5/650 [00:39<1:01:38,  5.73s/it]

{'loss': 2.613, 'grad_norm': 6.291839122772217, 'learning_rate': 3.846153846153847e-06, 'epoch': 0.38}


  2%|▏         | 10/650 [00:56<39:07,  3.67s/it] 

{'loss': 2.6684, 'grad_norm': 7.667351245880127, 'learning_rate': 7.692307692307694e-06, 'epoch': 0.77}


                                                
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 2.694891929626465, 'eval_accuracy': 0.03, 'eval_runtime': 4.1459, 'eval_samples_per_second': 24.12, 'eval_steps_per_second': 0.482, 'epoch': 1.0}


  2%|▏         | 15/650 [01:14<43:08,  4.08s/it]

{'loss': 2.6196, 'grad_norm': 6.438978672027588, 'learning_rate': 1.153846153846154e-05, 'epoch': 1.15}


  3%|▎         | 20/650 [01:27<31:10,  2.97s/it]

{'loss': 2.6004, 'grad_norm': 5.6341705322265625, 'learning_rate': 1.5384615384615387e-05, 'epoch': 1.54}


  4%|▍         | 25/650 [01:51<40:26,  3.88s/it]

{'loss': 2.5761, 'grad_norm': 6.729978561401367, 'learning_rate': 1.923076923076923e-05, 'epoch': 1.92}


                                                
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 2.6344237327575684, 'eval_accuracy': 0.02, 'eval_runtime': 3.8581, 'eval_samples_per_second': 25.92, 'eval_steps_per_second': 0.518, 'epoch': 2.0}


  5%|▍         | 30/650 [02:20<44:22,  4.29s/it]  

{'loss': 2.5843, 'grad_norm': 6.178511619567871, 'learning_rate': 2.307692307692308e-05, 'epoch': 2.31}


  5%|▌         | 35/650 [02:30<25:05,  2.45s/it]

{'loss': 2.5435, 'grad_norm': 6.059666156768799, 'learning_rate': 2.6923076923076923e-05, 'epoch': 2.69}


                                                
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 2.539423942565918, 'eval_accuracy': 0.03, 'eval_runtime': 3.4444, 'eval_samples_per_second': 29.033, 'eval_steps_per_second': 0.581, 'epoch': 3.0}


  6%|▌         | 40/650 [02:49<50:28,  4.96s/it]

{'loss': 2.4701, 'grad_norm': 5.869167804718018, 'learning_rate': 3.0769230769230774e-05, 'epoch': 3.08}


  7%|▋         | 45/650 [03:00<26:07,  2.59s/it]

{'loss': 2.4611, 'grad_norm': 5.67149543762207, 'learning_rate': 3.461538461538462e-05, 'epoch': 3.46}


  8%|▊         | 50/650 [03:11<22:27,  2.25s/it]

{'loss': 2.4321, 'grad_norm': 5.688538074493408, 'learning_rate': 3.846153846153846e-05, 'epoch': 3.85}


                                                
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 2.4192090034484863, 'eval_accuracy': 0.03, 'eval_runtime': 3.4824, 'eval_samples_per_second': 28.716, 'eval_steps_per_second': 0.574, 'epoch': 4.0}


  8%|▊         | 55/650 [03:25<27:23,  2.76s/it]

{'loss': 2.3729, 'grad_norm': 5.458479881286621, 'learning_rate': 4.230769230769231e-05, 'epoch': 4.23}


  9%|▉         | 60/650 [03:36<22:33,  2.29s/it]

{'loss': 2.2937, 'grad_norm': 5.0902934074401855, 'learning_rate': 4.615384615384616e-05, 'epoch': 4.62}


 10%|█         | 65/650 [03:46<17:10,  1.76s/it]

{'loss': 2.2877, 'grad_norm': 7.276853561401367, 'learning_rate': 5e-05, 'epoch': 5.0}


                                                
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 2.2736523151397705, 'eval_accuracy': 0.11, 'eval_runtime': 3.5358, 'eval_samples_per_second': 28.282, 'eval_steps_per_second': 0.566, 'epoch': 5.0}


 11%|█         | 70/650 [04:01<24:03,  2.49s/it]

{'loss': 2.2494, 'grad_norm': 5.589336395263672, 'learning_rate': 4.9572649572649575e-05, 'epoch': 5.38}


 12%|█▏        | 75/650 [04:13<21:48,  2.28s/it]

{'loss': 2.1594, 'grad_norm': 4.683980464935303, 'learning_rate': 4.9145299145299147e-05, 'epoch': 5.77}


                                                
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 2.1354784965515137, 'eval_accuracy': 0.19, 'eval_runtime': 3.5508, 'eval_samples_per_second': 28.162, 'eval_steps_per_second': 0.563, 'epoch': 6.0}


 12%|█▏        | 80/650 [04:27<28:34,  3.01s/it]

{'loss': 2.086, 'grad_norm': 5.12574577331543, 'learning_rate': 4.871794871794872e-05, 'epoch': 6.15}


 13%|█▎        | 85/650 [04:38<22:27,  2.38s/it]

{'loss': 2.0193, 'grad_norm': 4.634902477264404, 'learning_rate': 4.829059829059829e-05, 'epoch': 6.54}


 14%|█▍        | 90/650 [04:49<20:28,  2.19s/it]

{'loss': 2.0378, 'grad_norm': 5.0698089599609375, 'learning_rate': 4.786324786324787e-05, 'epoch': 6.92}


                                                
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 2.0089452266693115, 'eval_accuracy': 0.27, 'eval_runtime': 3.5761, 'eval_samples_per_second': 27.964, 'eval_steps_per_second': 0.559, 'epoch': 7.0}


 15%|█▍        | 95/650 [05:04<24:39,  2.67s/it]

{'loss': 1.9791, 'grad_norm': 4.990859031677246, 'learning_rate': 4.7435897435897435e-05, 'epoch': 7.31}


 15%|█▌        | 100/650 [05:15<21:23,  2.33s/it]

{'loss': 1.8974, 'grad_norm': 4.6273393630981445, 'learning_rate': 4.700854700854701e-05, 'epoch': 7.69}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.895595669746399, 'eval_accuracy': 0.35, 'eval_runtime': 3.641, 'eval_samples_per_second': 27.465, 'eval_steps_per_second': 0.549, 'epoch': 8.0}


 16%|█▌        | 105/650 [05:30<31:12,  3.44s/it]

{'loss': 1.8854, 'grad_norm': 5.007087707519531, 'learning_rate': 4.6581196581196586e-05, 'epoch': 8.08}


 17%|█▋        | 110/650 [05:41<22:05,  2.45s/it]

{'loss': 1.8031, 'grad_norm': 4.152040958404541, 'learning_rate': 4.615384615384616e-05, 'epoch': 8.46}


 18%|█▊        | 115/650 [05:53<20:29,  2.30s/it]

{'loss': 1.7962, 'grad_norm': 4.919305324554443, 'learning_rate': 4.572649572649573e-05, 'epoch': 8.85}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.790917992591858, 'eval_accuracy': 0.46, 'eval_runtime': 3.6396, 'eval_samples_per_second': 27.476, 'eval_steps_per_second': 0.55, 'epoch': 9.0}


 18%|█▊        | 120/650 [06:07<24:57,  2.83s/it]

{'loss': 1.7663, 'grad_norm': 4.294926166534424, 'learning_rate': 4.52991452991453e-05, 'epoch': 9.23}


 19%|█▉        | 125/650 [06:18<20:43,  2.37s/it]

{'loss': 1.7139, 'grad_norm': 4.754371643066406, 'learning_rate': 4.4871794871794874e-05, 'epoch': 9.62}


 20%|██        | 130/650 [06:28<15:40,  1.81s/it]

{'loss': 1.6752, 'grad_norm': 4.734320640563965, 'learning_rate': 4.4444444444444447e-05, 'epoch': 10.0}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.7004687786102295, 'eval_accuracy': 0.54, 'eval_runtime': 3.5859, 'eval_samples_per_second': 27.887, 'eval_steps_per_second': 0.558, 'epoch': 10.0}


 21%|██        | 135/650 [06:44<21:52,  2.55s/it]

{'loss': 1.658, 'grad_norm': 4.713266849517822, 'learning_rate': 4.401709401709402e-05, 'epoch': 10.38}


 22%|██▏       | 140/650 [06:56<19:48,  2.33s/it]

{'loss': 1.5675, 'grad_norm': 4.956681728363037, 'learning_rate': 4.358974358974359e-05, 'epoch': 10.77}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.6175636053085327, 'eval_accuracy': 0.56, 'eval_runtime': 3.5756, 'eval_samples_per_second': 27.967, 'eval_steps_per_second': 0.559, 'epoch': 11.0}


 22%|██▏       | 145/650 [07:10<25:22,  3.02s/it]

{'loss': 1.5784, 'grad_norm': 4.341971397399902, 'learning_rate': 4.316239316239317e-05, 'epoch': 11.15}


 23%|██▎       | 150/650 [07:21<20:00,  2.40s/it]

{'loss': 1.5843, 'grad_norm': 3.7563412189483643, 'learning_rate': 4.2735042735042735e-05, 'epoch': 11.54}


 24%|██▍       | 155/650 [07:33<18:20,  2.22s/it]

{'loss': 1.4897, 'grad_norm': 4.849435329437256, 'learning_rate': 4.230769230769231e-05, 'epoch': 11.92}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.539951205253601, 'eval_accuracy': 0.62, 'eval_runtime': 3.6276, 'eval_samples_per_second': 27.566, 'eval_steps_per_second': 0.551, 'epoch': 12.0}


 25%|██▍       | 160/650 [07:47<21:32,  2.64s/it]

{'loss': 1.4405, 'grad_norm': 3.848163366317749, 'learning_rate': 4.1880341880341886e-05, 'epoch': 12.31}


 25%|██▌       | 165/650 [07:59<19:05,  2.36s/it]

{'loss': 1.466, 'grad_norm': 3.576035737991333, 'learning_rate': 4.145299145299146e-05, 'epoch': 12.69}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.4689353704452515, 'eval_accuracy': 0.67, 'eval_runtime': 3.7357, 'eval_samples_per_second': 26.769, 'eval_steps_per_second': 0.535, 'epoch': 13.0}


 26%|██▌       | 170/650 [08:13<26:54,  3.36s/it]

{'loss': 1.3945, 'grad_norm': 3.8588831424713135, 'learning_rate': 4.1025641025641023e-05, 'epoch': 13.08}


 27%|██▋       | 175/650 [08:25<19:44,  2.49s/it]

{'loss': 1.3904, 'grad_norm': 3.9706549644470215, 'learning_rate': 4.05982905982906e-05, 'epoch': 13.46}


 28%|██▊       | 180/650 [08:36<18:44,  2.39s/it]

{'loss': 1.3595, 'grad_norm': 3.9564316272735596, 'learning_rate': 4.0170940170940174e-05, 'epoch': 13.85}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.4018356800079346, 'eval_accuracy': 0.67, 'eval_runtime': 3.8019, 'eval_samples_per_second': 26.302, 'eval_steps_per_second': 0.526, 'epoch': 14.0}


 28%|██▊       | 185/650 [08:51<22:35,  2.92s/it]

{'loss': 1.4012, 'grad_norm': 4.916412353515625, 'learning_rate': 3.974358974358974e-05, 'epoch': 14.23}


 29%|██▉       | 190/650 [09:03<18:54,  2.47s/it]

{'loss': 1.2995, 'grad_norm': 4.207294464111328, 'learning_rate': 3.931623931623932e-05, 'epoch': 14.62}


 30%|███       | 195/650 [09:14<14:26,  1.90s/it]

{'loss': 1.2713, 'grad_norm': 4.59133243560791, 'learning_rate': 3.888888888888889e-05, 'epoch': 15.0}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.3423629999160767, 'eval_accuracy': 0.69, 'eval_runtime': 3.7939, 'eval_samples_per_second': 26.358, 'eval_steps_per_second': 0.527, 'epoch': 15.0}


 31%|███       | 200/650 [09:30<19:49,  2.64s/it]

{'loss': 1.2695, 'grad_norm': 3.681856632232666, 'learning_rate': 3.846153846153846e-05, 'epoch': 15.38}


 32%|███▏      | 205/650 [09:42<18:08,  2.45s/it]

{'loss': 1.2307, 'grad_norm': 4.653223037719727, 'learning_rate': 3.8034188034188035e-05, 'epoch': 15.77}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.2904760837554932, 'eval_accuracy': 0.72, 'eval_runtime': 3.8012, 'eval_samples_per_second': 26.308, 'eval_steps_per_second': 0.526, 'epoch': 16.0}


 32%|███▏      | 210/650 [09:57<23:00,  3.14s/it]

{'loss': 1.247, 'grad_norm': 3.750153064727783, 'learning_rate': 3.760683760683761e-05, 'epoch': 16.15}


 33%|███▎      | 215/650 [10:09<18:05,  2.49s/it]

{'loss': 1.21, 'grad_norm': 3.309917688369751, 'learning_rate': 3.717948717948718e-05, 'epoch': 16.54}


 34%|███▍      | 220/650 [10:21<16:42,  2.33s/it]

{'loss': 1.1681, 'grad_norm': 4.039535045623779, 'learning_rate': 3.675213675213676e-05, 'epoch': 16.92}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.2416309118270874, 'eval_accuracy': 0.74, 'eval_runtime': 3.7708, 'eval_samples_per_second': 26.52, 'eval_steps_per_second': 0.53, 'epoch': 17.0}


 35%|███▍      | 225/650 [10:36<19:13,  2.71s/it]

{'loss': 1.2152, 'grad_norm': 3.4147865772247314, 'learning_rate': 3.6324786324786323e-05, 'epoch': 17.31}


 35%|███▌      | 230/650 [10:47<16:47,  2.40s/it]

{'loss': 1.1424, 'grad_norm': 3.458848476409912, 'learning_rate': 3.58974358974359e-05, 'epoch': 17.69}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.1966943740844727, 'eval_accuracy': 0.77, 'eval_runtime': 3.7733, 'eval_samples_per_second': 26.502, 'eval_steps_per_second': 0.53, 'epoch': 18.0}


 36%|███▌      | 235/650 [11:02<23:37,  3.42s/it]

{'loss': 1.0985, 'grad_norm': 3.254791259765625, 'learning_rate': 3.5470085470085474e-05, 'epoch': 18.08}


 37%|███▋      | 240/650 [11:13<16:59,  2.49s/it]

{'loss': 1.1526, 'grad_norm': 3.053083658218384, 'learning_rate': 3.504273504273504e-05, 'epoch': 18.46}


 38%|███▊      | 245/650 [11:25<16:03,  2.38s/it]

{'loss': 1.0789, 'grad_norm': 4.436598300933838, 'learning_rate': 3.461538461538462e-05, 'epoch': 18.85}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.1563037633895874, 'eval_accuracy': 0.76, 'eval_runtime': 3.7772, 'eval_samples_per_second': 26.475, 'eval_steps_per_second': 0.529, 'epoch': 19.0}


 38%|███▊      | 250/650 [11:40<19:23,  2.91s/it]

{'loss': 1.1043, 'grad_norm': 3.4851343631744385, 'learning_rate': 3.418803418803419e-05, 'epoch': 19.23}


 39%|███▉      | 255/650 [11:52<16:22,  2.49s/it]

{'loss': 1.0242, 'grad_norm': 2.927934169769287, 'learning_rate': 3.376068376068376e-05, 'epoch': 19.62}


 40%|████      | 260/650 [12:02<12:22,  1.90s/it]

{'loss': 1.0924, 'grad_norm': 4.21661376953125, 'learning_rate': 3.3333333333333335e-05, 'epoch': 20.0}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.1225171089172363, 'eval_accuracy': 0.77, 'eval_runtime': 3.8028, 'eval_samples_per_second': 26.296, 'eval_steps_per_second': 0.526, 'epoch': 20.0}


 41%|████      | 265/650 [12:19<17:21,  2.70s/it]

{'loss': 1.0131, 'grad_norm': 3.6923739910125732, 'learning_rate': 3.290598290598291e-05, 'epoch': 20.38}


 42%|████▏     | 270/650 [12:32<15:40,  2.48s/it]

{'loss': 1.058, 'grad_norm': 2.935326337814331, 'learning_rate': 3.247863247863248e-05, 'epoch': 20.77}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.087888240814209, 'eval_accuracy': 0.77, 'eval_runtime': 3.8203, 'eval_samples_per_second': 26.176, 'eval_steps_per_second': 0.524, 'epoch': 21.0}


 42%|████▏     | 275/650 [12:47<19:54,  3.19s/it]

{'loss': 0.9929, 'grad_norm': 3.3613181114196777, 'learning_rate': 3.205128205128206e-05, 'epoch': 21.15}


 43%|████▎     | 280/650 [12:59<15:36,  2.53s/it]

{'loss': 1.0268, 'grad_norm': 3.0184504985809326, 'learning_rate': 3.162393162393162e-05, 'epoch': 21.54}


 44%|████▍     | 285/650 [13:10<14:02,  2.31s/it]

{'loss': 0.9986, 'grad_norm': 3.4185497760772705, 'learning_rate': 3.1196581196581195e-05, 'epoch': 21.92}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.0589306354522705, 'eval_accuracy': 0.77, 'eval_runtime': 3.6912, 'eval_samples_per_second': 27.091, 'eval_steps_per_second': 0.542, 'epoch': 22.0}


 45%|████▍     | 290/650 [13:25<16:13,  2.70s/it]

{'loss': 0.9369, 'grad_norm': 3.193279981613159, 'learning_rate': 3.0769230769230774e-05, 'epoch': 22.31}


 45%|████▌     | 295/650 [13:37<14:05,  2.38s/it]

{'loss': 0.9603, 'grad_norm': 2.813493013381958, 'learning_rate': 3.034188034188034e-05, 'epoch': 22.69}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.0320227146148682, 'eval_accuracy': 0.77, 'eval_runtime': 3.6674, 'eval_samples_per_second': 27.267, 'eval_steps_per_second': 0.545, 'epoch': 23.0}


 46%|████▌     | 300/650 [13:51<19:28,  3.34s/it]

{'loss': 1.0077, 'grad_norm': 2.798794746398926, 'learning_rate': 2.9914529914529915e-05, 'epoch': 23.08}


 47%|████▋     | 305/650 [14:03<14:11,  2.47s/it]

{'loss': 0.9468, 'grad_norm': 2.934161901473999, 'learning_rate': 2.948717948717949e-05, 'epoch': 23.46}


 48%|████▊     | 310/650 [14:14<13:13,  2.33s/it]

{'loss': 0.9297, 'grad_norm': 3.266263723373413, 'learning_rate': 2.9059829059829063e-05, 'epoch': 23.85}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 1.0048657655715942, 'eval_accuracy': 0.77, 'eval_runtime': 3.622, 'eval_samples_per_second': 27.609, 'eval_steps_per_second': 0.552, 'epoch': 24.0}


 48%|████▊     | 315/650 [14:29<15:44,  2.82s/it]

{'loss': 0.9585, 'grad_norm': 3.0224359035491943, 'learning_rate': 2.863247863247863e-05, 'epoch': 24.23}


 49%|████▉     | 320/650 [14:40<13:09,  2.39s/it]

{'loss': 0.9131, 'grad_norm': 2.9646449089050293, 'learning_rate': 2.8205128205128207e-05, 'epoch': 24.62}


 50%|█████     | 325/650 [14:50<09:53,  1.83s/it]

{'loss': 0.8713, 'grad_norm': 4.573498249053955, 'learning_rate': 2.777777777777778e-05, 'epoch': 25.0}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.9809399247169495, 'eval_accuracy': 0.77, 'eval_runtime': 3.6431, 'eval_samples_per_second': 27.449, 'eval_steps_per_second': 0.549, 'epoch': 25.0}


 51%|█████     | 330/650 [15:06<13:34,  2.55s/it]

{'loss': 0.9124, 'grad_norm': 3.1074745655059814, 'learning_rate': 2.7350427350427355e-05, 'epoch': 25.38}


 52%|█████▏    | 335/650 [15:17<12:17,  2.34s/it]

{'loss': 0.8778, 'grad_norm': 3.0518219470977783, 'learning_rate': 2.6923076923076923e-05, 'epoch': 25.77}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.9585644602775574, 'eval_accuracy': 0.77, 'eval_runtime': 3.6609, 'eval_samples_per_second': 27.316, 'eval_steps_per_second': 0.546, 'epoch': 26.0}


 52%|█████▏    | 340/650 [15:32<15:40,  3.03s/it]

{'loss': 0.8909, 'grad_norm': 3.3651282787323, 'learning_rate': 2.64957264957265e-05, 'epoch': 26.15}


 53%|█████▎    | 345/650 [15:43<12:16,  2.41s/it]

{'loss': 0.8662, 'grad_norm': 2.451727867126465, 'learning_rate': 2.606837606837607e-05, 'epoch': 26.54}


 54%|█████▍    | 350/650 [15:55<11:12,  2.24s/it]

{'loss': 0.8584, 'grad_norm': 2.6536197662353516, 'learning_rate': 2.564102564102564e-05, 'epoch': 26.92}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.9398583769798279, 'eval_accuracy': 0.78, 'eval_runtime': 3.6728, 'eval_samples_per_second': 27.227, 'eval_steps_per_second': 0.545, 'epoch': 27.0}


 55%|█████▍    | 355/650 [16:09<12:59,  2.64s/it]

{'loss': 0.8225, 'grad_norm': 3.016152858734131, 'learning_rate': 2.5213675213675215e-05, 'epoch': 27.31}


 55%|█████▌    | 360/650 [16:21<11:27,  2.37s/it]

{'loss': 0.8928, 'grad_norm': 3.1989176273345947, 'learning_rate': 2.4786324786324787e-05, 'epoch': 27.69}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.9215453863143921, 'eval_accuracy': 0.79, 'eval_runtime': 3.7753, 'eval_samples_per_second': 26.488, 'eval_steps_per_second': 0.53, 'epoch': 28.0}


 56%|█████▌    | 365/650 [16:36<16:31,  3.48s/it]

{'loss': 0.8493, 'grad_norm': 2.797173023223877, 'learning_rate': 2.435897435897436e-05, 'epoch': 28.08}


 57%|█████▋    | 370/650 [16:47<11:43,  2.51s/it]

{'loss': 0.8029, 'grad_norm': 3.4366648197174072, 'learning_rate': 2.3931623931623935e-05, 'epoch': 28.46}


 58%|█████▊    | 375/650 [16:59<10:56,  2.39s/it]

{'loss': 0.8286, 'grad_norm': 2.6573903560638428, 'learning_rate': 2.3504273504273504e-05, 'epoch': 28.85}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.9068164229393005, 'eval_accuracy': 0.8, 'eval_runtime': 3.8051, 'eval_samples_per_second': 26.281, 'eval_steps_per_second': 0.526, 'epoch': 29.0}


 58%|█████▊    | 380/650 [17:14<12:56,  2.88s/it]

{'loss': 0.8333, 'grad_norm': 2.844472646713257, 'learning_rate': 2.307692307692308e-05, 'epoch': 29.23}


 59%|█████▉    | 385/650 [17:25<10:44,  2.43s/it]

{'loss': 0.7853, 'grad_norm': 2.4704959392547607, 'learning_rate': 2.264957264957265e-05, 'epoch': 29.62}


 60%|██████    | 390/650 [17:36<08:13,  1.90s/it]

{'loss': 0.8452, 'grad_norm': 3.941004753112793, 'learning_rate': 2.2222222222222223e-05, 'epoch': 30.0}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.8916430473327637, 'eval_accuracy': 0.8, 'eval_runtime': 3.8062, 'eval_samples_per_second': 26.273, 'eval_steps_per_second': 0.525, 'epoch': 30.0}


 61%|██████    | 395/650 [17:52<11:09,  2.62s/it]

{'loss': 0.8366, 'grad_norm': 2.75935435295105, 'learning_rate': 2.1794871794871795e-05, 'epoch': 30.38}


 62%|██████▏   | 400/650 [18:04<10:12,  2.45s/it]

{'loss': 0.7865, 'grad_norm': 2.6620469093322754, 'learning_rate': 2.1367521367521368e-05, 'epoch': 30.77}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.8777062892913818, 'eval_accuracy': 0.8, 'eval_runtime': 3.8097, 'eval_samples_per_second': 26.249, 'eval_steps_per_second': 0.525, 'epoch': 31.0}


 62%|██████▏   | 405/650 [18:19<12:56,  3.17s/it]

{'loss': 0.765, 'grad_norm': 2.7909374237060547, 'learning_rate': 2.0940170940170943e-05, 'epoch': 31.15}


 63%|██████▎   | 410/650 [18:31<10:11,  2.55s/it]

{'loss': 0.7757, 'grad_norm': 2.6031453609466553, 'learning_rate': 2.0512820512820512e-05, 'epoch': 31.54}


 64%|██████▍   | 415/650 [18:43<09:15,  2.36s/it]

{'loss': 0.7998, 'grad_norm': 2.5851356983184814, 'learning_rate': 2.0085470085470087e-05, 'epoch': 31.92}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.8641698956489563, 'eval_accuracy': 0.8, 'eval_runtime': 3.84, 'eval_samples_per_second': 26.042, 'eval_steps_per_second': 0.521, 'epoch': 32.0}


 65%|██████▍   | 420/650 [18:58<10:40,  2.79s/it]

{'loss': 0.7771, 'grad_norm': 2.5977110862731934, 'learning_rate': 1.965811965811966e-05, 'epoch': 32.31}


 65%|██████▌   | 425/650 [19:10<09:14,  2.46s/it]

{'loss': 0.7681, 'grad_norm': 2.2041707038879395, 'learning_rate': 1.923076923076923e-05, 'epoch': 32.69}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.853424072265625, 'eval_accuracy': 0.8, 'eval_runtime': 3.8482, 'eval_samples_per_second': 25.986, 'eval_steps_per_second': 0.52, 'epoch': 33.0}


 66%|██████▌   | 430/650 [19:26<12:52,  3.51s/it]

{'loss': 0.8461, 'grad_norm': 3.1440184116363525, 'learning_rate': 1.8803418803418804e-05, 'epoch': 33.08}


 67%|██████▋   | 435/650 [19:38<09:24,  2.63s/it]

{'loss': 0.8046, 'grad_norm': 2.6946046352386475, 'learning_rate': 1.837606837606838e-05, 'epoch': 33.46}


 68%|██████▊   | 440/650 [19:50<08:38,  2.47s/it]

{'loss': 0.7001, 'grad_norm': 2.6423027515411377, 'learning_rate': 1.794871794871795e-05, 'epoch': 33.85}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.8431031107902527, 'eval_accuracy': 0.8, 'eval_runtime': 3.8328, 'eval_samples_per_second': 26.091, 'eval_steps_per_second': 0.522, 'epoch': 34.0}


 68%|██████▊   | 445/650 [20:05<10:11,  2.98s/it]

{'loss': 0.7422, 'grad_norm': 2.365510940551758, 'learning_rate': 1.752136752136752e-05, 'epoch': 34.23}


 69%|██████▉   | 450/650 [20:17<08:23,  2.52s/it]

{'loss': 0.793, 'grad_norm': 2.3147659301757812, 'learning_rate': 1.7094017094017095e-05, 'epoch': 34.62}


 70%|███████   | 455/650 [20:28<06:16,  1.93s/it]

{'loss': 0.7391, 'grad_norm': 3.5022287368774414, 'learning_rate': 1.6666666666666667e-05, 'epoch': 35.0}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.8331799507141113, 'eval_accuracy': 0.8, 'eval_runtime': 3.8693, 'eval_samples_per_second': 25.844, 'eval_steps_per_second': 0.517, 'epoch': 35.0}


 71%|███████   | 460/650 [20:45<08:35,  2.71s/it]

{'loss': 0.7514, 'grad_norm': 2.6918962001800537, 'learning_rate': 1.623931623931624e-05, 'epoch': 35.38}


 72%|███████▏  | 465/650 [20:57<07:34,  2.46s/it]

{'loss': 0.7333, 'grad_norm': 1.9856960773468018, 'learning_rate': 1.581196581196581e-05, 'epoch': 35.77}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.8251738548278809, 'eval_accuracy': 0.8, 'eval_runtime': 3.805, 'eval_samples_per_second': 26.282, 'eval_steps_per_second': 0.526, 'epoch': 36.0}


 72%|███████▏  | 470/650 [21:12<09:27,  3.15s/it]

{'loss': 0.758, 'grad_norm': 2.03110933303833, 'learning_rate': 1.5384615384615387e-05, 'epoch': 36.15}


 73%|███████▎  | 475/650 [21:24<07:15,  2.49s/it]

{'loss': 0.6912, 'grad_norm': 2.9480361938476562, 'learning_rate': 1.4957264957264958e-05, 'epoch': 36.54}


 74%|███████▍  | 480/650 [21:35<06:27,  2.28s/it]

{'loss': 0.7678, 'grad_norm': 2.2658097743988037, 'learning_rate': 1.4529914529914531e-05, 'epoch': 36.92}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.8184106945991516, 'eval_accuracy': 0.8, 'eval_runtime': 3.7424, 'eval_samples_per_second': 26.721, 'eval_steps_per_second': 0.534, 'epoch': 37.0}


 75%|███████▍  | 485/650 [21:50<07:24,  2.69s/it]

{'loss': 0.7313, 'grad_norm': 2.123532295227051, 'learning_rate': 1.4102564102564104e-05, 'epoch': 37.31}


 75%|███████▌  | 490/650 [22:02<06:23,  2.39s/it]

{'loss': 0.7775, 'grad_norm': 2.6983389854431152, 'learning_rate': 1.3675213675213677e-05, 'epoch': 37.69}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.8112116456031799, 'eval_accuracy': 0.8, 'eval_runtime': 3.6562, 'eval_samples_per_second': 27.351, 'eval_steps_per_second': 0.547, 'epoch': 38.0}


 76%|███████▌  | 495/650 [22:16<08:35,  3.33s/it]

{'loss': 0.7162, 'grad_norm': 2.470912218093872, 'learning_rate': 1.324786324786325e-05, 'epoch': 38.08}


 77%|███████▋  | 500/650 [22:28<06:11,  2.48s/it]

{'loss': 0.7347, 'grad_norm': 2.6733105182647705, 'learning_rate': 1.282051282051282e-05, 'epoch': 38.46}


 78%|███████▊  | 505/650 [22:39<05:40,  2.35s/it]

{'loss': 0.7102, 'grad_norm': 2.0576467514038086, 'learning_rate': 1.2393162393162394e-05, 'epoch': 38.85}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.8048401474952698, 'eval_accuracy': 0.8, 'eval_runtime': 3.6709, 'eval_samples_per_second': 27.241, 'eval_steps_per_second': 0.545, 'epoch': 39.0}


 78%|███████▊  | 510/650 [22:54<06:36,  2.83s/it]

{'loss': 0.7174, 'grad_norm': 2.518043279647827, 'learning_rate': 1.1965811965811967e-05, 'epoch': 39.23}


 79%|███████▉  | 515/650 [23:05<05:27,  2.42s/it]

{'loss': 0.735, 'grad_norm': 2.728104829788208, 'learning_rate': 1.153846153846154e-05, 'epoch': 39.62}


 80%|████████  | 520/650 [23:15<04:03,  1.87s/it]

{'loss': 0.7247, 'grad_norm': 3.479991912841797, 'learning_rate': 1.1111111111111112e-05, 'epoch': 40.0}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.7995508313179016, 'eval_accuracy': 0.8, 'eval_runtime': 3.7091, 'eval_samples_per_second': 26.961, 'eval_steps_per_second': 0.539, 'epoch': 40.0}


 81%|████████  | 525/650 [23:32<05:20,  2.57s/it]

{'loss': 0.7199, 'grad_norm': 2.1722121238708496, 'learning_rate': 1.0683760683760684e-05, 'epoch': 40.38}


 82%|████████▏ | 530/650 [23:43<04:44,  2.37s/it]

{'loss': 0.7193, 'grad_norm': 2.213534116744995, 'learning_rate': 1.0256410256410256e-05, 'epoch': 40.77}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.7950543165206909, 'eval_accuracy': 0.79, 'eval_runtime': 3.6663, 'eval_samples_per_second': 27.275, 'eval_steps_per_second': 0.546, 'epoch': 41.0}


 82%|████████▏ | 535/650 [23:58<05:52,  3.07s/it]

{'loss': 0.7203, 'grad_norm': 2.22472882270813, 'learning_rate': 9.82905982905983e-06, 'epoch': 41.15}


 83%|████████▎ | 540/650 [24:09<04:28,  2.44s/it]

{'loss': 0.695, 'grad_norm': 2.516953229904175, 'learning_rate': 9.401709401709402e-06, 'epoch': 41.54}


 84%|████████▍ | 545/650 [24:21<03:56,  2.25s/it]

{'loss': 0.6956, 'grad_norm': 1.9854487180709839, 'learning_rate': 8.974358974358976e-06, 'epoch': 41.92}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.7900927066802979, 'eval_accuracy': 0.79, 'eval_runtime': 3.6061, 'eval_samples_per_second': 27.731, 'eval_steps_per_second': 0.555, 'epoch': 42.0}


 85%|████████▍ | 550/650 [24:35<04:24,  2.65s/it]

{'loss': 0.738, 'grad_norm': 2.2502458095550537, 'learning_rate': 8.547008547008548e-06, 'epoch': 42.31}


 85%|████████▌ | 555/650 [24:47<03:45,  2.37s/it]

{'loss': 0.7002, 'grad_norm': 2.3912360668182373, 'learning_rate': 8.11965811965812e-06, 'epoch': 42.69}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.7867462038993835, 'eval_accuracy': 0.79, 'eval_runtime': 3.6303, 'eval_samples_per_second': 27.546, 'eval_steps_per_second': 0.551, 'epoch': 43.0}


 86%|████████▌ | 560/650 [25:01<05:00,  3.34s/it]

{'loss': 0.6818, 'grad_norm': 2.280057191848755, 'learning_rate': 7.692307692307694e-06, 'epoch': 43.08}


 87%|████████▋ | 565/650 [25:13<03:31,  2.49s/it]

{'loss': 0.7123, 'grad_norm': 2.7859208583831787, 'learning_rate': 7.264957264957266e-06, 'epoch': 43.46}


 88%|████████▊ | 570/650 [25:24<03:06,  2.33s/it]

{'loss': 0.6851, 'grad_norm': 2.560289144515991, 'learning_rate': 6.837606837606839e-06, 'epoch': 43.85}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.783780574798584, 'eval_accuracy': 0.8, 'eval_runtime': 3.6103, 'eval_samples_per_second': 27.699, 'eval_steps_per_second': 0.554, 'epoch': 44.0}


 88%|████████▊ | 575/650 [25:39<03:30,  2.80s/it]

{'loss': 0.7413, 'grad_norm': 2.461509943008423, 'learning_rate': 6.41025641025641e-06, 'epoch': 44.23}


 89%|████████▉ | 580/650 [25:50<02:47,  2.39s/it]

{'loss': 0.6668, 'grad_norm': 2.639981746673584, 'learning_rate': 5.982905982905984e-06, 'epoch': 44.62}


 90%|█████████ | 585/650 [26:00<01:59,  1.83s/it]

{'loss': 0.6973, 'grad_norm': 3.364818811416626, 'learning_rate': 5.555555555555556e-06, 'epoch': 45.0}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.781203031539917, 'eval_accuracy': 0.8, 'eval_runtime': 3.7237, 'eval_samples_per_second': 26.855, 'eval_steps_per_second': 0.537, 'epoch': 45.0}


 91%|█████████ | 590/650 [26:16<02:33,  2.55s/it]

{'loss': 0.6863, 'grad_norm': 2.2012434005737305, 'learning_rate': 5.128205128205128e-06, 'epoch': 45.38}


 92%|█████████▏| 595/650 [26:28<02:08,  2.35s/it]

{'loss': 0.6922, 'grad_norm': 2.2365031242370605, 'learning_rate': 4.700854700854701e-06, 'epoch': 45.77}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.7791222929954529, 'eval_accuracy': 0.8, 'eval_runtime': 3.6191, 'eval_samples_per_second': 27.631, 'eval_steps_per_second': 0.553, 'epoch': 46.0}


 92%|█████████▏| 600/650 [26:42<02:31,  3.02s/it]

{'loss': 0.7251, 'grad_norm': 2.317014694213867, 'learning_rate': 4.273504273504274e-06, 'epoch': 46.15}


 93%|█████████▎| 605/650 [26:53<01:49,  2.42s/it]

{'loss': 0.6659, 'grad_norm': 2.245941638946533, 'learning_rate': 3.846153846153847e-06, 'epoch': 46.54}


 94%|█████████▍| 610/650 [27:05<01:30,  2.26s/it]

{'loss': 0.7236, 'grad_norm': 1.8115315437316895, 'learning_rate': 3.4188034188034193e-06, 'epoch': 46.92}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.7770483493804932, 'eval_accuracy': 0.8, 'eval_runtime': 3.6147, 'eval_samples_per_second': 27.665, 'eval_steps_per_second': 0.553, 'epoch': 47.0}


 95%|█████████▍| 615/650 [27:19<01:32,  2.64s/it]

{'loss': 0.682, 'grad_norm': 1.840643286705017, 'learning_rate': 2.991452991452992e-06, 'epoch': 47.31}


 95%|█████████▌| 620/650 [27:31<01:10,  2.36s/it]

{'loss': 0.7313, 'grad_norm': 2.214906692504883, 'learning_rate': 2.564102564102564e-06, 'epoch': 47.69}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.7759350538253784, 'eval_accuracy': 0.8, 'eval_runtime': 3.6262, 'eval_samples_per_second': 27.577, 'eval_steps_per_second': 0.552, 'epoch': 48.0}


 96%|█████████▌| 625/650 [27:45<01:22,  3.31s/it]

{'loss': 0.6433, 'grad_norm': 2.120664358139038, 'learning_rate': 2.136752136752137e-06, 'epoch': 48.08}


 97%|█████████▋| 630/650 [27:56<00:49,  2.48s/it]

{'loss': 0.6817, 'grad_norm': 2.2112364768981934, 'learning_rate': 1.7094017094017097e-06, 'epoch': 48.46}


 98%|█████████▊| 635/650 [28:08<00:35,  2.36s/it]

{'loss': 0.7171, 'grad_norm': 2.4461944103240967, 'learning_rate': 1.282051282051282e-06, 'epoch': 48.85}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.7750914692878723, 'eval_accuracy': 0.8, 'eval_runtime': 3.6245, 'eval_samples_per_second': 27.59, 'eval_steps_per_second': 0.552, 'epoch': 49.0}


 98%|█████████▊| 640/650 [28:22<00:27,  2.80s/it]

{'loss': 0.659, 'grad_norm': 2.5099470615386963, 'learning_rate': 8.547008547008548e-07, 'epoch': 49.23}


 99%|█████████▉| 645/650 [28:34<00:11,  2.36s/it]

{'loss': 0.7302, 'grad_norm': 2.4073848724365234, 'learning_rate': 4.273504273504274e-07, 'epoch': 49.62}


Non-default generation parameters: {'max_length': 1024}


{'loss': 0.661, 'grad_norm': 4.074517250061035, 'learning_rate': 0.0, 'epoch': 50.0}


                                                 
Non-default generation parameters: {'max_length': 1024}


{'eval_loss': 0.7748748660087585, 'eval_accuracy': 0.8, 'eval_runtime': 3.6867, 'eval_samples_per_second': 27.124, 'eval_steps_per_second': 0.542, 'epoch': 50.0}


100%|██████████| 650/650 [28:48<00:00,  2.66s/it]

{'train_runtime': 1728.7528, 'train_samples_per_second': 23.109, 'train_steps_per_second': 0.376, 'train_loss': 1.177377460186298, 'epoch': 50.0}





TrainOutput(global_step=650, training_loss=1.177377460186298, metrics={'train_runtime': 1728.7528, 'train_samples_per_second': 23.109, 'train_steps_per_second': 0.376, 'total_flos': 2.708117737046016e+18, 'train_loss': 1.177377460186298, 'epoch': 50.0})

In [11]:
pred=trainer.predict(gtzan_encoded["test"].with_format("torch"))
pred.metrics

100%|██████████| 2/2 [00:00<00:00,  2.42it/s]


{'test_loss': 0.839857816696167,
 'test_accuracy': 0.84,
 'test_runtime': 3.4015,
 'test_samples_per_second': 29.399,
 'test_steps_per_second': 0.588}