In [37]:
#import required libraries
from datasets import load_dataset
from datasets import load_dataset_builder

import datasets
import numpy as np
import pandas as pd

In [38]:
ds_builder = load_dataset_builder('Falah/Alzheimer_MRI')
print(f"Dataset Name: {ds_builder.info.dataset_name}")
labels = ds_builder.info.features["label"].names
print(f"Class Labels: {labels}")

Dataset Name: alzheimer_mri
Class Labels: ['Mild_Demented', 'Moderate_Demented', 'Non_Demented', 'Very_Mild_Demented']


In [40]:
alzheimer = load_dataset('Falah/Alzheimer_MRI', split='train')
# test_dataset = load_dataset('Falah/Alzheimer_MRI', split="test")
# alzheimer = datasets.DatasetDict({"train":train_dataset,"test":test_dataset})

alzheimer = alzheimer.train_test_split(test_size=0.2)

In [41]:
import numpy as np
imgdata = alzheimer["train"][0]
img = imgdata['image']
imgnp = np.array(img)
print(imgnp.shape)

(128, 128)


In [42]:
#To get the labels from the training set:
labels = alzheimer["train"].features['label'].names #--> returns unique labels. 
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    #print(f'id: {i}, label:{label}')
    label2id[label] = str(i)
    id2label[str(i)] = label

In [43]:
#Print label ids to a label names:
print(f' Label-0 -- > {id2label[str(0)]} \n Label-1 -- > {id2label[str(1)]} \n Label-2 -- > {id2label[str(2)]} \n Label-3 -- > {id2label[str(3)]}')

 Label-0 -- > Mild_Demented 
 Label-1 -- > Moderate_Demented 
 Label-2 -- > Non_Demented 
 Label-3 -- > Very_Mild_Demented


In [44]:
#Next step is to load a ViT image processor to process the image into a tensor
from transformers import AutoImageProcessor
#checkpoint = 'google/vit-base-patch16-224-in21k' #works 
#checkpoint = 'google/vit-hybrid-base-bit-384' --> this can also be tried.
#checkpoint = 'microsoft/swin-base-patch4-window7-224-in22k' 
#checkpoint = 'microsoft/beit-base-patch16-384'  #cuda out of memory
#checkpoint = 'microsoft/swin-base-patch4-window7-224' #works
#checkpoint = 'facebook/deit-base-distilled-patch16-224'#error in training (labels not required)
#checkpoint = 'facebook/deit-base-patch16-224' #works
checkpoint = 'microsoft/beit-base-patch16-224-pt22k-ft22k'
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [45]:
s=checkpoint.split('/')
dest_dir = "models/" + checkpoint.split('/')[1]
dest_dir

'models/beit-base-patch16-224-pt22k-ft22k'

In [46]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

size = (image_processor.size["shortest_edge"]
        if "shortest_edge" in image_processor.size
        else (image_processor.size["height"], image_processor.size["width"])
)

_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

print(_transforms)


Compose(
    RandomResizedCrop(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear, antialias=warn)
    ToTensor()
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
)


In [47]:
#Create a preprocessing function to apply the transforms and return the pixel_values - the
# inputs to the model -of the image:
def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


In [48]:
alzheimer = alzheimer.with_transform(transforms)
alzheimer

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 4096
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 1024
    })
})

In [35]:
#alzheimer['train'][:]['label']

alzheimer['train']['pixel_values']

In [49]:
import numpy as np
d=alzheimer['train'][0]
img = d['pixel_values']
img = np.array(img)
print(img.shape)

(3, 224, 224)


In [50]:
# Now create a batch of examples using DefaultDataCollator. Gap: (Now is this DataLoader of Pytorch?)
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()

In [51]:
import evaluate
accuracy = evaluate.load("accuracy")

In [52]:
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions,axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [53]:
from transformers import Trainer, AutoModelForImageClassification, TrainingArguments

model = AutoModelForImageClassification.from_pretrained(checkpoint,
                                                        num_labels=len(labels),
                                                        id2label = id2label,
                                                        label2id = label2id,
                                                        ignore_mismatched_sizes=True
                                                        )


Some weights of BeitForImageClassification were not initialized from the model checkpoint at microsoft/beit-base-patch16-224-pt22k-ft22k and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([21841, 768]) in the checkpoint and torch.Size([4, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([21841]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [54]:
training_args = TrainingArguments(
    output_dir="models/" + checkpoint.split('/')[1],
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=50,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

In [55]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=alzheimer["train"],
    eval_dataset=alzheimer["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

In [56]:
trainer.train()

# trainer.evaluate() — The dataset to use for evaluation.
#  If it is a Dataset, columns not accepted by the model.forward() method are automatically removed. 
# If it is a dictionary, it will evaluate on each dataset prepending the dictionary key to the metric name.

  0%|          | 0/3200 [00:00<?, ?it/s]

{'loss': 1.3514, 'learning_rate': 1.5625e-06, 'epoch': 0.16}
{'loss': 1.1677, 'learning_rate': 3.125e-06, 'epoch': 0.31}
{'loss': 1.0548, 'learning_rate': 4.6875000000000004e-06, 'epoch': 0.47}
{'loss': 1.0066, 'learning_rate': 6.25e-06, 'epoch': 0.62}
{'loss': 0.9683, 'learning_rate': 7.8125e-06, 'epoch': 0.78}
{'loss': 0.9493, 'learning_rate': 9.375000000000001e-06, 'epoch': 0.94}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.9605082869529724, 'eval_accuracy': 0.564453125, 'eval_runtime': 10.5124, 'eval_samples_per_second': 97.409, 'eval_steps_per_second': 6.088, 'epoch': 1.0}
{'loss': 0.9266, 'learning_rate': 1.09375e-05, 'epoch': 1.09}
{'loss': 0.8731, 'learning_rate': 1.25e-05, 'epoch': 1.25}
{'loss': 0.9275, 'learning_rate': 1.4062500000000001e-05, 'epoch': 1.41}
{'loss': 0.9191, 'learning_rate': 1.5625e-05, 'epoch': 1.56}
{'loss': 0.8515, 'learning_rate': 1.71875e-05, 'epoch': 1.72}
{'loss': 0.8779, 'learning_rate': 1.8750000000000002e-05, 'epoch': 1.88}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.9015829563140869, 'eval_accuracy': 0.5615234375, 'eval_runtime': 10.5675, 'eval_samples_per_second': 96.901, 'eval_steps_per_second': 6.056, 'epoch': 2.0}
{'loss': 0.8536, 'learning_rate': 2.0312500000000002e-05, 'epoch': 2.03}
{'loss': 0.8536, 'learning_rate': 2.1875e-05, 'epoch': 2.19}
{'loss': 0.8854, 'learning_rate': 2.34375e-05, 'epoch': 2.34}
{'loss': 0.8177, 'learning_rate': 2.5e-05, 'epoch': 2.5}
{'loss': 0.8798, 'learning_rate': 2.6562500000000002e-05, 'epoch': 2.66}
{'loss': 0.8066, 'learning_rate': 2.8125000000000003e-05, 'epoch': 2.81}
{'loss': 0.8721, 'learning_rate': 2.96875e-05, 'epoch': 2.97}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.8284130692481995, 'eval_accuracy': 0.62890625, 'eval_runtime': 10.838, 'eval_samples_per_second': 94.483, 'eval_steps_per_second': 5.905, 'epoch': 3.0}
{'loss': 0.8568, 'learning_rate': 3.125e-05, 'epoch': 3.12}
{'loss': 0.7323, 'learning_rate': 3.2812500000000005e-05, 'epoch': 3.28}
{'loss': 0.7867, 'learning_rate': 3.4375e-05, 'epoch': 3.44}
{'loss': 0.8035, 'learning_rate': 3.59375e-05, 'epoch': 3.59}
{'loss': 0.8245, 'learning_rate': 3.7500000000000003e-05, 'epoch': 3.75}
{'loss': 0.741, 'learning_rate': 3.90625e-05, 'epoch': 3.91}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.8019998073577881, 'eval_accuracy': 0.6279296875, 'eval_runtime': 10.8617, 'eval_samples_per_second': 94.276, 'eval_steps_per_second': 5.892, 'epoch': 4.0}
{'loss': 0.8182, 'learning_rate': 4.0625000000000005e-05, 'epoch': 4.06}
{'loss': 0.7887, 'learning_rate': 4.21875e-05, 'epoch': 4.22}
{'loss': 0.8412, 'learning_rate': 4.375e-05, 'epoch': 4.38}
{'loss': 0.7982, 'learning_rate': 4.5312500000000004e-05, 'epoch': 4.53}
{'loss': 0.8104, 'learning_rate': 4.6875e-05, 'epoch': 4.69}
{'loss': 0.7437, 'learning_rate': 4.8437500000000005e-05, 'epoch': 4.84}
{'loss': 0.8038, 'learning_rate': 5e-05, 'epoch': 5.0}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.7774355411529541, 'eval_accuracy': 0.6572265625, 'eval_runtime': 10.8612, 'eval_samples_per_second': 94.281, 'eval_steps_per_second': 5.893, 'epoch': 5.0}
{'loss': 0.7368, 'learning_rate': 4.982638888888889e-05, 'epoch': 5.16}
{'loss': 0.7242, 'learning_rate': 4.965277777777778e-05, 'epoch': 5.31}
{'loss': 0.8471, 'learning_rate': 4.947916666666667e-05, 'epoch': 5.47}
{'loss': 0.6586, 'learning_rate': 4.930555555555556e-05, 'epoch': 5.62}
{'loss': 0.7452, 'learning_rate': 4.9131944444444445e-05, 'epoch': 5.78}
{'loss': 0.7183, 'learning_rate': 4.8958333333333335e-05, 'epoch': 5.94}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.7078962326049805, 'eval_accuracy': 0.6875, 'eval_runtime': 10.9719, 'eval_samples_per_second': 93.329, 'eval_steps_per_second': 5.833, 'epoch': 6.0}
{'loss': 0.6844, 'learning_rate': 4.8784722222222225e-05, 'epoch': 6.09}
{'loss': 0.6485, 'learning_rate': 4.8611111111111115e-05, 'epoch': 6.25}
{'loss': 0.696, 'learning_rate': 4.8437500000000005e-05, 'epoch': 6.41}
{'loss': 0.6506, 'learning_rate': 4.8263888888888895e-05, 'epoch': 6.56}
{'loss': 0.6575, 'learning_rate': 4.809027777777778e-05, 'epoch': 6.72}
{'loss': 0.607, 'learning_rate': 4.791666666666667e-05, 'epoch': 6.88}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.6658731698989868, 'eval_accuracy': 0.7158203125, 'eval_runtime': 10.9461, 'eval_samples_per_second': 93.549, 'eval_steps_per_second': 5.847, 'epoch': 7.0}
{'loss': 0.6484, 'learning_rate': 4.774305555555556e-05, 'epoch': 7.03}
{'loss': 0.555, 'learning_rate': 4.756944444444444e-05, 'epoch': 7.19}
{'loss': 0.5378, 'learning_rate': 4.739583333333333e-05, 'epoch': 7.34}
{'loss': 0.5462, 'learning_rate': 4.722222222222222e-05, 'epoch': 7.5}
{'loss': 0.538, 'learning_rate': 4.704861111111111e-05, 'epoch': 7.66}
{'loss': 0.5739, 'learning_rate': 4.6875e-05, 'epoch': 7.81}
{'loss': 0.5652, 'learning_rate': 4.670138888888889e-05, 'epoch': 7.97}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.5409120917320251, 'eval_accuracy': 0.7626953125, 'eval_runtime': 10.9546, 'eval_samples_per_second': 93.477, 'eval_steps_per_second': 5.842, 'epoch': 8.0}
{'loss': 0.4885, 'learning_rate': 4.652777777777778e-05, 'epoch': 8.12}
{'loss': 0.5004, 'learning_rate': 4.635416666666667e-05, 'epoch': 8.28}
{'loss': 0.5004, 'learning_rate': 4.618055555555556e-05, 'epoch': 8.44}
{'loss': 0.4355, 'learning_rate': 4.6006944444444444e-05, 'epoch': 8.59}
{'loss': 0.5371, 'learning_rate': 4.5833333333333334e-05, 'epoch': 8.75}
{'loss': 0.5312, 'learning_rate': 4.5659722222222224e-05, 'epoch': 8.91}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.6239457130432129, 'eval_accuracy': 0.740234375, 'eval_runtime': 10.6775, 'eval_samples_per_second': 95.903, 'eval_steps_per_second': 5.994, 'epoch': 9.0}
{'loss': 0.5393, 'learning_rate': 4.5486111111111114e-05, 'epoch': 9.06}
{'loss': 0.4482, 'learning_rate': 4.5312500000000004e-05, 'epoch': 9.22}
{'loss': 0.4603, 'learning_rate': 4.5138888888888894e-05, 'epoch': 9.38}
{'loss': 0.4352, 'learning_rate': 4.4965277777777784e-05, 'epoch': 9.53}
{'loss': 0.3981, 'learning_rate': 4.4791666666666673e-05, 'epoch': 9.69}
{'loss': 0.4876, 'learning_rate': 4.4618055555555563e-05, 'epoch': 9.84}
{'loss': 0.4272, 'learning_rate': 4.4444444444444447e-05, 'epoch': 10.0}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.4804430603981018, 'eval_accuracy': 0.8056640625, 'eval_runtime': 10.9024, 'eval_samples_per_second': 93.924, 'eval_steps_per_second': 5.87, 'epoch': 10.0}
{'loss': 0.3882, 'learning_rate': 4.4270833333333337e-05, 'epoch': 10.16}
{'loss': 0.4121, 'learning_rate': 4.4097222222222226e-05, 'epoch': 10.31}
{'loss': 0.437, 'learning_rate': 4.392361111111111e-05, 'epoch': 10.47}
{'loss': 0.4122, 'learning_rate': 4.375e-05, 'epoch': 10.62}
{'loss': 0.3656, 'learning_rate': 4.357638888888889e-05, 'epoch': 10.78}
{'loss': 0.3574, 'learning_rate': 4.340277777777778e-05, 'epoch': 10.94}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.4548543095588684, 'eval_accuracy': 0.8212890625, 'eval_runtime': 10.817, 'eval_samples_per_second': 94.666, 'eval_steps_per_second': 5.917, 'epoch': 11.0}
{'loss': 0.3654, 'learning_rate': 4.322916666666667e-05, 'epoch': 11.09}
{'loss': 0.4721, 'learning_rate': 4.305555555555556e-05, 'epoch': 11.25}
{'loss': 0.4084, 'learning_rate': 4.288194444444444e-05, 'epoch': 11.41}
{'loss': 0.4244, 'learning_rate': 4.270833333333333e-05, 'epoch': 11.56}
{'loss': 0.3493, 'learning_rate': 4.253472222222222e-05, 'epoch': 11.72}
{'loss': 0.3714, 'learning_rate': 4.236111111111111e-05, 'epoch': 11.88}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.4329036772251129, 'eval_accuracy': 0.8271484375, 'eval_runtime': 10.673, 'eval_samples_per_second': 95.943, 'eval_steps_per_second': 5.996, 'epoch': 12.0}
{'loss': 0.3381, 'learning_rate': 4.21875e-05, 'epoch': 12.03}
{'loss': 0.3675, 'learning_rate': 4.201388888888889e-05, 'epoch': 12.19}
{'loss': 0.4018, 'learning_rate': 4.184027777777778e-05, 'epoch': 12.34}
{'loss': 0.3378, 'learning_rate': 4.166666666666667e-05, 'epoch': 12.5}
{'loss': 0.248, 'learning_rate': 4.149305555555556e-05, 'epoch': 12.66}
{'loss': 0.2896, 'learning_rate': 4.1319444444444445e-05, 'epoch': 12.81}
{'loss': 0.3221, 'learning_rate': 4.1145833333333335e-05, 'epoch': 12.97}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.5023562908172607, 'eval_accuracy': 0.806640625, 'eval_runtime': 10.7452, 'eval_samples_per_second': 95.298, 'eval_steps_per_second': 5.956, 'epoch': 13.0}
{'loss': 0.3139, 'learning_rate': 4.0972222222222225e-05, 'epoch': 13.12}
{'loss': 0.2964, 'learning_rate': 4.0798611111111115e-05, 'epoch': 13.28}
{'loss': 0.3405, 'learning_rate': 4.0625000000000005e-05, 'epoch': 13.44}
{'loss': 0.2734, 'learning_rate': 4.045138888888889e-05, 'epoch': 13.59}
{'loss': 0.2737, 'learning_rate': 4.027777777777778e-05, 'epoch': 13.75}
{'loss': 0.2996, 'learning_rate': 4.010416666666667e-05, 'epoch': 13.91}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.42720139026641846, 'eval_accuracy': 0.82421875, 'eval_runtime': 10.8068, 'eval_samples_per_second': 94.755, 'eval_steps_per_second': 5.922, 'epoch': 14.0}
{'loss': 0.2812, 'learning_rate': 3.993055555555556e-05, 'epoch': 14.06}
{'loss': 0.2603, 'learning_rate': 3.975694444444444e-05, 'epoch': 14.22}
{'loss': 0.2821, 'learning_rate': 3.958333333333333e-05, 'epoch': 14.38}
{'loss': 0.2912, 'learning_rate': 3.940972222222222e-05, 'epoch': 14.53}
{'loss': 0.2728, 'learning_rate': 3.923611111111111e-05, 'epoch': 14.69}
{'loss': 0.3051, 'learning_rate': 3.90625e-05, 'epoch': 14.84}
{'loss': 0.2551, 'learning_rate': 3.888888888888889e-05, 'epoch': 15.0}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.4075433313846588, 'eval_accuracy': 0.8359375, 'eval_runtime': 10.7376, 'eval_samples_per_second': 95.366, 'eval_steps_per_second': 5.96, 'epoch': 15.0}
{'loss': 0.2676, 'learning_rate': 3.871527777777778e-05, 'epoch': 15.16}
{'loss': 0.2838, 'learning_rate': 3.854166666666667e-05, 'epoch': 15.31}
{'loss': 0.2546, 'learning_rate': 3.836805555555556e-05, 'epoch': 15.47}
{'loss': 0.2457, 'learning_rate': 3.8194444444444444e-05, 'epoch': 15.62}
{'loss': 0.3348, 'learning_rate': 3.8020833333333334e-05, 'epoch': 15.78}
{'loss': 0.2801, 'learning_rate': 3.7847222222222224e-05, 'epoch': 15.94}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.3377915918827057, 'eval_accuracy': 0.86328125, 'eval_runtime': 10.7769, 'eval_samples_per_second': 95.018, 'eval_steps_per_second': 5.939, 'epoch': 16.0}
{'loss': 0.2171, 'learning_rate': 3.7673611111111114e-05, 'epoch': 16.09}
{'loss': 0.2171, 'learning_rate': 3.7500000000000003e-05, 'epoch': 16.25}
{'loss': 0.2864, 'learning_rate': 3.7326388888888893e-05, 'epoch': 16.41}
{'loss': 0.2342, 'learning_rate': 3.715277777777778e-05, 'epoch': 16.56}
{'loss': 0.2617, 'learning_rate': 3.697916666666667e-05, 'epoch': 16.72}
{'loss': 0.2104, 'learning_rate': 3.6805555555555556e-05, 'epoch': 16.88}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.28229087591171265, 'eval_accuracy': 0.8896484375, 'eval_runtime': 10.7272, 'eval_samples_per_second': 95.458, 'eval_steps_per_second': 5.966, 'epoch': 17.0}
{'loss': 0.2253, 'learning_rate': 3.6631944444444446e-05, 'epoch': 17.03}
{'loss': 0.2829, 'learning_rate': 3.6458333333333336e-05, 'epoch': 17.19}
{'loss': 0.2787, 'learning_rate': 3.628472222222222e-05, 'epoch': 17.34}
{'loss': 0.2674, 'learning_rate': 3.611111111111111e-05, 'epoch': 17.5}
{'loss': 0.2352, 'learning_rate': 3.59375e-05, 'epoch': 17.66}
{'loss': 0.2186, 'learning_rate': 3.576388888888889e-05, 'epoch': 17.81}
{'loss': 0.1884, 'learning_rate': 3.559027777777778e-05, 'epoch': 17.97}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.2875107526779175, 'eval_accuracy': 0.888671875, 'eval_runtime': 10.7939, 'eval_samples_per_second': 94.868, 'eval_steps_per_second': 5.929, 'epoch': 18.0}
{'loss': 0.1845, 'learning_rate': 3.541666666666667e-05, 'epoch': 18.12}
{'loss': 0.2206, 'learning_rate': 3.524305555555556e-05, 'epoch': 18.28}
{'loss': 0.2521, 'learning_rate': 3.506944444444444e-05, 'epoch': 18.44}
{'loss': 0.2212, 'learning_rate': 3.489583333333333e-05, 'epoch': 18.59}
{'loss': 0.2383, 'learning_rate': 3.472222222222222e-05, 'epoch': 18.75}
{'loss': 0.2326, 'learning_rate': 3.454861111111111e-05, 'epoch': 18.91}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.3412787616252899, 'eval_accuracy': 0.8564453125, 'eval_runtime': 10.7153, 'eval_samples_per_second': 95.564, 'eval_steps_per_second': 5.973, 'epoch': 19.0}
{'loss': 0.2399, 'learning_rate': 3.4375e-05, 'epoch': 19.06}
{'loss': 0.2026, 'learning_rate': 3.420138888888889e-05, 'epoch': 19.22}
{'loss': 0.2069, 'learning_rate': 3.402777777777778e-05, 'epoch': 19.38}
{'loss': 0.216, 'learning_rate': 3.385416666666667e-05, 'epoch': 19.53}
{'loss': 0.2022, 'learning_rate': 3.368055555555556e-05, 'epoch': 19.69}
{'loss': 0.2419, 'learning_rate': 3.3506944444444445e-05, 'epoch': 19.84}
{'loss': 0.2156, 'learning_rate': 3.3333333333333335e-05, 'epoch': 20.0}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.2789207696914673, 'eval_accuracy': 0.8896484375, 'eval_runtime': 10.7832, 'eval_samples_per_second': 94.963, 'eval_steps_per_second': 5.935, 'epoch': 20.0}
{'loss': 0.1576, 'learning_rate': 3.3159722222222225e-05, 'epoch': 20.16}
{'loss': 0.1752, 'learning_rate': 3.2986111111111115e-05, 'epoch': 20.31}
{'loss': 0.2295, 'learning_rate': 3.2812500000000005e-05, 'epoch': 20.47}
{'loss': 0.2339, 'learning_rate': 3.263888888888889e-05, 'epoch': 20.62}
{'loss': 0.1888, 'learning_rate': 3.246527777777778e-05, 'epoch': 20.78}
{'loss': 0.2125, 'learning_rate': 3.229166666666667e-05, 'epoch': 20.94}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.27255624532699585, 'eval_accuracy': 0.890625, 'eval_runtime': 10.8097, 'eval_samples_per_second': 94.73, 'eval_steps_per_second': 5.921, 'epoch': 21.0}
{'loss': 0.1835, 'learning_rate': 3.211805555555556e-05, 'epoch': 21.09}
{'loss': 0.1807, 'learning_rate': 3.194444444444444e-05, 'epoch': 21.25}
{'loss': 0.1871, 'learning_rate': 3.177083333333333e-05, 'epoch': 21.41}
{'loss': 0.2577, 'learning_rate': 3.159722222222222e-05, 'epoch': 21.56}
{'loss': 0.1526, 'learning_rate': 3.142361111111111e-05, 'epoch': 21.72}
{'loss': 0.1842, 'learning_rate': 3.125e-05, 'epoch': 21.88}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.24645495414733887, 'eval_accuracy': 0.9013671875, 'eval_runtime': 10.7392, 'eval_samples_per_second': 95.352, 'eval_steps_per_second': 5.959, 'epoch': 22.0}
{'loss': 0.1816, 'learning_rate': 3.107638888888889e-05, 'epoch': 22.03}
{'loss': 0.1783, 'learning_rate': 3.090277777777778e-05, 'epoch': 22.19}
{'loss': 0.1651, 'learning_rate': 3.072916666666667e-05, 'epoch': 22.34}
{'loss': 0.1846, 'learning_rate': 3.055555555555556e-05, 'epoch': 22.5}
{'loss': 0.1866, 'learning_rate': 3.0381944444444444e-05, 'epoch': 22.66}
{'loss': 0.1633, 'learning_rate': 3.0208333333333334e-05, 'epoch': 22.81}
{'loss': 0.1775, 'learning_rate': 3.0034722222222223e-05, 'epoch': 22.97}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.2638329863548279, 'eval_accuracy': 0.9013671875, 'eval_runtime': 10.7383, 'eval_samples_per_second': 95.359, 'eval_steps_per_second': 5.96, 'epoch': 23.0}
{'loss': 0.1256, 'learning_rate': 2.9861111111111113e-05, 'epoch': 23.12}
{'loss': 0.1407, 'learning_rate': 2.96875e-05, 'epoch': 23.28}
{'loss': 0.2091, 'learning_rate': 2.951388888888889e-05, 'epoch': 23.44}
{'loss': 0.1778, 'learning_rate': 2.934027777777778e-05, 'epoch': 23.59}
{'loss': 0.163, 'learning_rate': 2.916666666666667e-05, 'epoch': 23.75}
{'loss': 0.1395, 'learning_rate': 2.899305555555556e-05, 'epoch': 23.91}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.24040023982524872, 'eval_accuracy': 0.9033203125, 'eval_runtime': 10.6811, 'eval_samples_per_second': 95.87, 'eval_steps_per_second': 5.992, 'epoch': 24.0}
{'loss': 0.1884, 'learning_rate': 2.8819444444444443e-05, 'epoch': 24.06}
{'loss': 0.1333, 'learning_rate': 2.8645833333333333e-05, 'epoch': 24.22}
{'loss': 0.175, 'learning_rate': 2.8472222222222223e-05, 'epoch': 24.38}
{'loss': 0.1466, 'learning_rate': 2.8298611111111113e-05, 'epoch': 24.53}
{'loss': 0.1775, 'learning_rate': 2.8125000000000003e-05, 'epoch': 24.69}
{'loss': 0.1532, 'learning_rate': 2.795138888888889e-05, 'epoch': 24.84}
{'loss': 0.1786, 'learning_rate': 2.777777777777778e-05, 'epoch': 25.0}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.2551637887954712, 'eval_accuracy': 0.9052734375, 'eval_runtime': 10.6351, 'eval_samples_per_second': 96.284, 'eval_steps_per_second': 6.018, 'epoch': 25.0}
{'loss': 0.1261, 'learning_rate': 2.760416666666667e-05, 'epoch': 25.16}
{'loss': 0.1488, 'learning_rate': 2.743055555555556e-05, 'epoch': 25.31}
{'loss': 0.1439, 'learning_rate': 2.7256944444444442e-05, 'epoch': 25.47}
{'loss': 0.1548, 'learning_rate': 2.7083333333333332e-05, 'epoch': 25.62}
{'loss': 0.1638, 'learning_rate': 2.6909722222222222e-05, 'epoch': 25.78}
{'loss': 0.177, 'learning_rate': 2.6736111111111112e-05, 'epoch': 25.94}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.23082773387432098, 'eval_accuracy': 0.921875, 'eval_runtime': 10.7664, 'eval_samples_per_second': 95.111, 'eval_steps_per_second': 5.944, 'epoch': 26.0}
{'loss': 0.1531, 'learning_rate': 2.6562500000000002e-05, 'epoch': 26.09}
{'loss': 0.1536, 'learning_rate': 2.6388888888888892e-05, 'epoch': 26.25}
{'loss': 0.1504, 'learning_rate': 2.6215277777777782e-05, 'epoch': 26.41}
{'loss': 0.1613, 'learning_rate': 2.604166666666667e-05, 'epoch': 26.56}
{'loss': 0.1872, 'learning_rate': 2.5868055555555558e-05, 'epoch': 26.72}
{'loss': 0.1718, 'learning_rate': 2.5694444444444445e-05, 'epoch': 26.88}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.20928910374641418, 'eval_accuracy': 0.91796875, 'eval_runtime': 10.6921, 'eval_samples_per_second': 95.771, 'eval_steps_per_second': 5.986, 'epoch': 27.0}
{'loss': 0.1467, 'learning_rate': 2.552083333333333e-05, 'epoch': 27.03}
{'loss': 0.176, 'learning_rate': 2.534722222222222e-05, 'epoch': 27.19}
{'loss': 0.1317, 'learning_rate': 2.517361111111111e-05, 'epoch': 27.34}
{'loss': 0.1278, 'learning_rate': 2.5e-05, 'epoch': 27.5}
{'loss': 0.1348, 'learning_rate': 2.482638888888889e-05, 'epoch': 27.66}
{'loss': 0.1263, 'learning_rate': 2.465277777777778e-05, 'epoch': 27.81}
{'loss': 0.1448, 'learning_rate': 2.4479166666666668e-05, 'epoch': 27.97}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.25165855884552, 'eval_accuracy': 0.904296875, 'eval_runtime': 10.697, 'eval_samples_per_second': 95.728, 'eval_steps_per_second': 5.983, 'epoch': 28.0}
{'loss': 0.12, 'learning_rate': 2.4305555555555558e-05, 'epoch': 28.12}
{'loss': 0.1795, 'learning_rate': 2.4131944444444448e-05, 'epoch': 28.28}
{'loss': 0.145, 'learning_rate': 2.3958333333333334e-05, 'epoch': 28.44}
{'loss': 0.1627, 'learning_rate': 2.378472222222222e-05, 'epoch': 28.59}
{'loss': 0.1738, 'learning_rate': 2.361111111111111e-05, 'epoch': 28.75}
{'loss': 0.1466, 'learning_rate': 2.34375e-05, 'epoch': 28.91}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.1945890635251999, 'eval_accuracy': 0.9287109375, 'eval_runtime': 10.7837, 'eval_samples_per_second': 94.958, 'eval_steps_per_second': 5.935, 'epoch': 29.0}
{'loss': 0.1514, 'learning_rate': 2.326388888888889e-05, 'epoch': 29.06}
{'loss': 0.1363, 'learning_rate': 2.309027777777778e-05, 'epoch': 29.22}
{'loss': 0.1371, 'learning_rate': 2.2916666666666667e-05, 'epoch': 29.38}
{'loss': 0.1449, 'learning_rate': 2.2743055555555557e-05, 'epoch': 29.53}
{'loss': 0.1429, 'learning_rate': 2.2569444444444447e-05, 'epoch': 29.69}
{'loss': 0.1406, 'learning_rate': 2.2395833333333337e-05, 'epoch': 29.84}
{'loss': 0.1303, 'learning_rate': 2.2222222222222223e-05, 'epoch': 30.0}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.21188655495643616, 'eval_accuracy': 0.923828125, 'eval_runtime': 10.7451, 'eval_samples_per_second': 95.299, 'eval_steps_per_second': 5.956, 'epoch': 30.0}
{'loss': 0.1024, 'learning_rate': 2.2048611111111113e-05, 'epoch': 30.16}
{'loss': 0.1247, 'learning_rate': 2.1875e-05, 'epoch': 30.31}
{'loss': 0.1156, 'learning_rate': 2.170138888888889e-05, 'epoch': 30.47}
{'loss': 0.1402, 'learning_rate': 2.152777777777778e-05, 'epoch': 30.62}
{'loss': 0.1327, 'learning_rate': 2.1354166666666666e-05, 'epoch': 30.78}
{'loss': 0.11, 'learning_rate': 2.1180555555555556e-05, 'epoch': 30.94}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.2311210185289383, 'eval_accuracy': 0.916015625, 'eval_runtime': 10.8703, 'eval_samples_per_second': 94.201, 'eval_steps_per_second': 5.888, 'epoch': 31.0}
{'loss': 0.1567, 'learning_rate': 2.1006944444444446e-05, 'epoch': 31.09}
{'loss': 0.1266, 'learning_rate': 2.0833333333333336e-05, 'epoch': 31.25}
{'loss': 0.1287, 'learning_rate': 2.0659722222222223e-05, 'epoch': 31.41}
{'loss': 0.1477, 'learning_rate': 2.0486111111111113e-05, 'epoch': 31.56}
{'loss': 0.1178, 'learning_rate': 2.0312500000000002e-05, 'epoch': 31.72}
{'loss': 0.1353, 'learning_rate': 2.013888888888889e-05, 'epoch': 31.88}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.30340099334716797, 'eval_accuracy': 0.9052734375, 'eval_runtime': 10.661, 'eval_samples_per_second': 96.051, 'eval_steps_per_second': 6.003, 'epoch': 32.0}
{'loss': 0.1496, 'learning_rate': 1.996527777777778e-05, 'epoch': 32.03}
{'loss': 0.1284, 'learning_rate': 1.9791666666666665e-05, 'epoch': 32.19}
{'loss': 0.0996, 'learning_rate': 1.9618055555555555e-05, 'epoch': 32.34}
{'loss': 0.1259, 'learning_rate': 1.9444444444444445e-05, 'epoch': 32.5}
{'loss': 0.1319, 'learning_rate': 1.9270833333333335e-05, 'epoch': 32.66}
{'loss': 0.1155, 'learning_rate': 1.9097222222222222e-05, 'epoch': 32.81}
{'loss': 0.0814, 'learning_rate': 1.8923611111111112e-05, 'epoch': 32.97}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.2152693271636963, 'eval_accuracy': 0.931640625, 'eval_runtime': 11.1294, 'eval_samples_per_second': 92.009, 'eval_steps_per_second': 5.751, 'epoch': 33.0}
{'loss': 0.1184, 'learning_rate': 1.8750000000000002e-05, 'epoch': 33.12}
{'loss': 0.1311, 'learning_rate': 1.857638888888889e-05, 'epoch': 33.28}
{'loss': 0.1427, 'learning_rate': 1.8402777777777778e-05, 'epoch': 33.44}
{'loss': 0.1309, 'learning_rate': 1.8229166666666668e-05, 'epoch': 33.59}
{'loss': 0.1198, 'learning_rate': 1.8055555555555555e-05, 'epoch': 33.75}
{'loss': 0.1104, 'learning_rate': 1.7881944444444445e-05, 'epoch': 33.91}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.1886013001203537, 'eval_accuracy': 0.93359375, 'eval_runtime': 11.0232, 'eval_samples_per_second': 92.895, 'eval_steps_per_second': 5.806, 'epoch': 34.0}
{'loss': 0.0879, 'learning_rate': 1.7708333333333335e-05, 'epoch': 34.06}
{'loss': 0.1337, 'learning_rate': 1.753472222222222e-05, 'epoch': 34.22}
{'loss': 0.1181, 'learning_rate': 1.736111111111111e-05, 'epoch': 34.38}
{'loss': 0.1242, 'learning_rate': 1.71875e-05, 'epoch': 34.53}
{'loss': 0.113, 'learning_rate': 1.701388888888889e-05, 'epoch': 34.69}
{'loss': 0.1313, 'learning_rate': 1.684027777777778e-05, 'epoch': 34.84}
{'loss': 0.1136, 'learning_rate': 1.6666666666666667e-05, 'epoch': 35.0}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.19725865125656128, 'eval_accuracy': 0.9296875, 'eval_runtime': 11.0743, 'eval_samples_per_second': 92.466, 'eval_steps_per_second': 5.779, 'epoch': 35.0}
{'loss': 0.1094, 'learning_rate': 1.6493055555555557e-05, 'epoch': 35.16}
{'loss': 0.0919, 'learning_rate': 1.6319444444444444e-05, 'epoch': 35.31}
{'loss': 0.0995, 'learning_rate': 1.6145833333333334e-05, 'epoch': 35.47}
{'loss': 0.0978, 'learning_rate': 1.597222222222222e-05, 'epoch': 35.62}
{'loss': 0.1209, 'learning_rate': 1.579861111111111e-05, 'epoch': 35.78}
{'loss': 0.1182, 'learning_rate': 1.5625e-05, 'epoch': 35.94}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.22977213561534882, 'eval_accuracy': 0.91796875, 'eval_runtime': 11.092, 'eval_samples_per_second': 92.319, 'eval_steps_per_second': 5.77, 'epoch': 36.0}
{'loss': 0.0978, 'learning_rate': 1.545138888888889e-05, 'epoch': 36.09}
{'loss': 0.0932, 'learning_rate': 1.527777777777778e-05, 'epoch': 36.25}
{'loss': 0.0926, 'learning_rate': 1.5104166666666667e-05, 'epoch': 36.41}
{'loss': 0.1123, 'learning_rate': 1.4930555555555557e-05, 'epoch': 36.56}
{'loss': 0.1648, 'learning_rate': 1.4756944444444445e-05, 'epoch': 36.72}
{'loss': 0.1172, 'learning_rate': 1.4583333333333335e-05, 'epoch': 36.88}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.1994379460811615, 'eval_accuracy': 0.923828125, 'eval_runtime': 10.8948, 'eval_samples_per_second': 93.989, 'eval_steps_per_second': 5.874, 'epoch': 37.0}
{'loss': 0.1084, 'learning_rate': 1.4409722222222221e-05, 'epoch': 37.03}
{'loss': 0.1252, 'learning_rate': 1.4236111111111111e-05, 'epoch': 37.19}
{'loss': 0.1164, 'learning_rate': 1.4062500000000001e-05, 'epoch': 37.34}
{'loss': 0.0933, 'learning_rate': 1.388888888888889e-05, 'epoch': 37.5}
{'loss': 0.0798, 'learning_rate': 1.371527777777778e-05, 'epoch': 37.66}
{'loss': 0.094, 'learning_rate': 1.3541666666666666e-05, 'epoch': 37.81}
{'loss': 0.1148, 'learning_rate': 1.3368055555555556e-05, 'epoch': 37.97}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.23331668972969055, 'eval_accuracy': 0.919921875, 'eval_runtime': 11.0365, 'eval_samples_per_second': 92.783, 'eval_steps_per_second': 5.799, 'epoch': 38.0}
{'loss': 0.1066, 'learning_rate': 1.3194444444444446e-05, 'epoch': 38.12}
{'loss': 0.1032, 'learning_rate': 1.3020833333333334e-05, 'epoch': 38.28}
{'loss': 0.0929, 'learning_rate': 1.2847222222222222e-05, 'epoch': 38.44}
{'loss': 0.107, 'learning_rate': 1.267361111111111e-05, 'epoch': 38.59}
{'loss': 0.1044, 'learning_rate': 1.25e-05, 'epoch': 38.75}
{'loss': 0.1012, 'learning_rate': 1.232638888888889e-05, 'epoch': 38.91}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.1653902232646942, 'eval_accuracy': 0.935546875, 'eval_runtime': 11.0968, 'eval_samples_per_second': 92.279, 'eval_steps_per_second': 5.767, 'epoch': 39.0}
{'loss': 0.1077, 'learning_rate': 1.2152777777777779e-05, 'epoch': 39.06}
{'loss': 0.1038, 'learning_rate': 1.1979166666666667e-05, 'epoch': 39.22}
{'loss': 0.0945, 'learning_rate': 1.1805555555555555e-05, 'epoch': 39.38}
{'loss': 0.1022, 'learning_rate': 1.1631944444444445e-05, 'epoch': 39.53}
{'loss': 0.0894, 'learning_rate': 1.1458333333333333e-05, 'epoch': 39.69}
{'loss': 0.116, 'learning_rate': 1.1284722222222223e-05, 'epoch': 39.84}
{'loss': 0.0864, 'learning_rate': 1.1111111111111112e-05, 'epoch': 40.0}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.1799265742301941, 'eval_accuracy': 0.9365234375, 'eval_runtime': 10.9593, 'eval_samples_per_second': 93.437, 'eval_steps_per_second': 5.84, 'epoch': 40.0}
{'loss': 0.0781, 'learning_rate': 1.09375e-05, 'epoch': 40.16}
{'loss': 0.1009, 'learning_rate': 1.076388888888889e-05, 'epoch': 40.31}
{'loss': 0.0927, 'learning_rate': 1.0590277777777778e-05, 'epoch': 40.47}
{'loss': 0.1007, 'learning_rate': 1.0416666666666668e-05, 'epoch': 40.62}
{'loss': 0.0953, 'learning_rate': 1.0243055555555556e-05, 'epoch': 40.78}
{'loss': 0.104, 'learning_rate': 1.0069444444444445e-05, 'epoch': 40.94}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.169968381524086, 'eval_accuracy': 0.9345703125, 'eval_runtime': 11.0485, 'eval_samples_per_second': 92.682, 'eval_steps_per_second': 5.793, 'epoch': 41.0}
{'loss': 0.0837, 'learning_rate': 9.895833333333333e-06, 'epoch': 41.09}
{'loss': 0.0764, 'learning_rate': 9.722222222222223e-06, 'epoch': 41.25}
{'loss': 0.0835, 'learning_rate': 9.548611111111111e-06, 'epoch': 41.41}
{'loss': 0.072, 'learning_rate': 9.375000000000001e-06, 'epoch': 41.56}
{'loss': 0.0748, 'learning_rate': 9.201388888888889e-06, 'epoch': 41.72}
{'loss': 0.0757, 'learning_rate': 9.027777777777777e-06, 'epoch': 41.88}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.23123861849308014, 'eval_accuracy': 0.9248046875, 'eval_runtime': 11.0923, 'eval_samples_per_second': 92.316, 'eval_steps_per_second': 5.77, 'epoch': 42.0}
{'loss': 0.0747, 'learning_rate': 8.854166666666667e-06, 'epoch': 42.03}
{'loss': 0.0879, 'learning_rate': 8.680555555555556e-06, 'epoch': 42.19}
{'loss': 0.0815, 'learning_rate': 8.506944444444445e-06, 'epoch': 42.34}
{'loss': 0.0716, 'learning_rate': 8.333333333333334e-06, 'epoch': 42.5}
{'loss': 0.0694, 'learning_rate': 8.159722222222222e-06, 'epoch': 42.66}
{'loss': 0.0925, 'learning_rate': 7.98611111111111e-06, 'epoch': 42.81}
{'loss': 0.0718, 'learning_rate': 7.8125e-06, 'epoch': 42.97}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.19643914699554443, 'eval_accuracy': 0.9287109375, 'eval_runtime': 11.1227, 'eval_samples_per_second': 92.064, 'eval_steps_per_second': 5.754, 'epoch': 43.0}
{'loss': 0.1056, 'learning_rate': 7.63888888888889e-06, 'epoch': 43.12}
{'loss': 0.0864, 'learning_rate': 7.465277777777778e-06, 'epoch': 43.28}
{'loss': 0.0999, 'learning_rate': 7.2916666666666674e-06, 'epoch': 43.44}
{'loss': 0.0746, 'learning_rate': 7.118055555555556e-06, 'epoch': 43.59}
{'loss': 0.0549, 'learning_rate': 6.944444444444445e-06, 'epoch': 43.75}
{'loss': 0.0745, 'learning_rate': 6.770833333333333e-06, 'epoch': 43.91}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.18794822692871094, 'eval_accuracy': 0.9443359375, 'eval_runtime': 11.178, 'eval_samples_per_second': 91.609, 'eval_steps_per_second': 5.726, 'epoch': 44.0}
{'loss': 0.0988, 'learning_rate': 6.597222222222223e-06, 'epoch': 44.06}
{'loss': 0.0851, 'learning_rate': 6.423611111111111e-06, 'epoch': 44.22}
{'loss': 0.0965, 'learning_rate': 6.25e-06, 'epoch': 44.38}
{'loss': 0.0665, 'learning_rate': 6.076388888888889e-06, 'epoch': 44.53}
{'loss': 0.0965, 'learning_rate': 5.902777777777778e-06, 'epoch': 44.69}
{'loss': 0.0804, 'learning_rate': 5.729166666666667e-06, 'epoch': 44.84}
{'loss': 0.0816, 'learning_rate': 5.555555555555556e-06, 'epoch': 45.0}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.15866482257843018, 'eval_accuracy': 0.9453125, 'eval_runtime': 11.0291, 'eval_samples_per_second': 92.845, 'eval_steps_per_second': 5.803, 'epoch': 45.0}
{'loss': 0.0898, 'learning_rate': 5.381944444444445e-06, 'epoch': 45.16}
{'loss': 0.1036, 'learning_rate': 5.208333333333334e-06, 'epoch': 45.31}
{'loss': 0.1058, 'learning_rate': 5.034722222222222e-06, 'epoch': 45.47}
{'loss': 0.0638, 'learning_rate': 4.861111111111111e-06, 'epoch': 45.62}
{'loss': 0.0847, 'learning_rate': 4.6875000000000004e-06, 'epoch': 45.78}
{'loss': 0.0707, 'learning_rate': 4.513888888888889e-06, 'epoch': 45.94}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.1483646035194397, 'eval_accuracy': 0.9423828125, 'eval_runtime': 10.8034, 'eval_samples_per_second': 94.785, 'eval_steps_per_second': 5.924, 'epoch': 46.0}
{'loss': 0.1092, 'learning_rate': 4.340277777777778e-06, 'epoch': 46.09}
{'loss': 0.0557, 'learning_rate': 4.166666666666667e-06, 'epoch': 46.25}
{'loss': 0.0709, 'learning_rate': 3.993055555555555e-06, 'epoch': 46.41}
{'loss': 0.0572, 'learning_rate': 3.819444444444445e-06, 'epoch': 46.56}
{'loss': 0.0655, 'learning_rate': 3.6458333333333337e-06, 'epoch': 46.72}
{'loss': 0.0873, 'learning_rate': 3.4722222222222224e-06, 'epoch': 46.88}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.1548689901828766, 'eval_accuracy': 0.94921875, 'eval_runtime': 10.6117, 'eval_samples_per_second': 96.498, 'eval_steps_per_second': 6.031, 'epoch': 47.0}
{'loss': 0.0848, 'learning_rate': 3.2986111111111115e-06, 'epoch': 47.03}
{'loss': 0.0525, 'learning_rate': 3.125e-06, 'epoch': 47.19}
{'loss': 0.0864, 'learning_rate': 2.951388888888889e-06, 'epoch': 47.34}
{'loss': 0.0752, 'learning_rate': 2.777777777777778e-06, 'epoch': 47.5}
{'loss': 0.078, 'learning_rate': 2.604166666666667e-06, 'epoch': 47.66}
{'loss': 0.0741, 'learning_rate': 2.4305555555555557e-06, 'epoch': 47.81}
{'loss': 0.0681, 'learning_rate': 2.2569444444444443e-06, 'epoch': 47.97}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.19879980385303497, 'eval_accuracy': 0.927734375, 'eval_runtime': 10.7018, 'eval_samples_per_second': 95.685, 'eval_steps_per_second': 5.98, 'epoch': 48.0}
{'loss': 0.0902, 'learning_rate': 2.0833333333333334e-06, 'epoch': 48.12}
{'loss': 0.0723, 'learning_rate': 1.9097222222222225e-06, 'epoch': 48.28}
{'loss': 0.1042, 'learning_rate': 1.7361111111111112e-06, 'epoch': 48.44}
{'loss': 0.111, 'learning_rate': 1.5625e-06, 'epoch': 48.59}
{'loss': 0.0786, 'learning_rate': 1.388888888888889e-06, 'epoch': 48.75}
{'loss': 0.0892, 'learning_rate': 1.2152777777777778e-06, 'epoch': 48.91}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.15349119901657104, 'eval_accuracy': 0.955078125, 'eval_runtime': 10.7715, 'eval_samples_per_second': 95.065, 'eval_steps_per_second': 5.942, 'epoch': 49.0}
{'loss': 0.057, 'learning_rate': 1.0416666666666667e-06, 'epoch': 49.06}
{'loss': 0.0512, 'learning_rate': 8.680555555555556e-07, 'epoch': 49.22}
{'loss': 0.091, 'learning_rate': 6.944444444444445e-07, 'epoch': 49.38}
{'loss': 0.0842, 'learning_rate': 5.208333333333334e-07, 'epoch': 49.53}
{'loss': 0.0851, 'learning_rate': 3.4722222222222224e-07, 'epoch': 49.69}
{'loss': 0.0722, 'learning_rate': 1.7361111111111112e-07, 'epoch': 49.84}
{'loss': 0.0725, 'learning_rate': 0.0, 'epoch': 50.0}


  0%|          | 0/64 [00:00<?, ?it/s]

{'eval_loss': 0.1159391850233078, 'eval_accuracy': 0.9638671875, 'eval_runtime': 10.8811, 'eval_samples_per_second': 94.108, 'eval_steps_per_second': 5.882, 'epoch': 50.0}
{'train_runtime': 5829.6633, 'train_samples_per_second': 35.131, 'train_steps_per_second': 0.549, 'train_loss': 0.27567509358748793, 'epoch': 50.0}


TrainOutput(global_step=3200, training_loss=0.27567509358748793, metrics={'train_runtime': 5829.6633, 'train_samples_per_second': 35.131, 'train_steps_per_second': 0.549, 'train_loss': 0.27567509358748793, 'epoch': 50.0})

In [35]:
#Read the Test Data and do the same preprocessing.
testdata = load_dataset('Falah/Alzheimer_MRI', split='test')
testdata = testdata.with_transform(transforms)

In [36]:
#Method-2: using evaluate (Gives overall accuracy)
results=trainer.evaluate(testdata)
## produces the below output ### 
{'eval_loss': 2.117784023284912,
 'eval_accuracy': 0.831,
 'eval_runtime': 310.6302,
 'eval_samples_per_second': 3.219,
 'eval_steps_per_second': 0.203,
 'epoch': 1.98}
print(results)
# df=pd.DataFrame(results)
# df.to_csv('/Results/evaluation_results.csv',index=False)

  0%|          | 0/80 [00:00<?, ?it/s]

{'eval_loss': 0.2076101303100586, 'eval_accuracy': 0.93125, 'eval_runtime': 13.0413, 'eval_samples_per_second': 98.15, 'eval_steps_per_second': 6.134, 'epoch': 50.0}


In [22]:
results

{'eval_loss': 0.3861716091632843,
 'eval_accuracy': 0.859375,
 'eval_runtime': 12.5942,
 'eval_samples_per_second': 101.634,
 'eval_steps_per_second': 6.352,
 'epoch': 20.0}

In [None]:
#use of pipeline to get the predictions --> Way Forward. 


In [18]:
history

TrainOutput(global_step=1280, training_loss=0.5180638987571001, metrics={'train_runtime': 2372.3217, 'train_samples_per_second': 34.532, 'train_steps_per_second': 0.54, 'train_loss': 0.5180638987571001, 'epoch': 20.0})

It depends on what you’d like to do, trainer.evaluate() will predict + compute metrics on your test set and trainer.predict() will only predict labels on your test set. However in case the test set also contains ground-truth labels, the latter will also compute metrics.

In [69]:
#Method-3: Using predict (Gives predicted label, probability values for each sample along with metrics)
results_prediction = trainer.predict(testdata)

  0%|          | 0/80 [00:00<?, ?it/s]

In [113]:
results_prediction

PredictionOutput(predictions=array([[-0.5591147 , -2.0389035 ,  1.242404  ,  0.98351747],
       [-1.0933479 , -1.7672849 ,  1.4555196 ,  0.76520914],
       [-1.2960124 , -1.39749   ,  2.033903  , -0.04442274],
       ...,
       [-1.2275118 , -1.7329051 ,  1.8181549 ,  0.63654417],
       [-1.1876274 , -1.6070359 ,  1.8790383 ,  0.40170136],
       [ 0.2324425 , -1.971694  ,  0.30193636,  1.0467229 ]],
      dtype=float32), label_ids=array([3, 0, 2, ..., 3, 2, 3], dtype=int64), metrics={'test_loss': 0.8645501136779785, 'test_accuracy': 0.584375, 'test_runtime': 12.1436, 'test_samples_per_second': 105.405, 'test_steps_per_second': 6.588})

In [85]:
prediction=results_prediction[1] # predicted labels
actual = np.array(testdata[:]['label']) #ground-truth labels
print(prediction.shape)
print(actual.shape)

(1280,)
(1280,)


In [116]:
proba = results_prediction[0]
#proba
#print(proba.shape)
prediction1=np.argmax(proba, axis=1)
#print(actual1.shape)
#print(proba[96])
#print(prediction1[96])

print(prediction[1:10])
print(prediction1[1:10])


[0 2 3 0 3 2 2 2 3]
[2 2 3 3 3 3 3 3 3]


In [108]:
print(proba[56])

print(np.argmax(proba[56]))


[ 0.73476785 -1.6395384  -0.03571419  0.6846446 ]
0
