In [None]:
from datasets import load_dataset
dataset = load_dataset("keremberke/chest-xray-classification",'full')
dataset.save_to_disk("./chest_xray_clasnsificatio")

In [None]:
from datasets import DatasetDict, concatenate_datasets

dataset = DatasetDict.load_from_disk("./chest_xray_clasnsificatio")

def select_samples(dataset, label_value, num_samples):
    filtered = dataset.filter(lambda example: example['labels'] == label_value)
    return filtered.select(range(min(num_samples, len(filtered))))

def create_balanced_subset(dataset, num_samples=100):
    balanced_dataset = {}
    for split in dataset.keys():
        label_0_samples = select_samples(dataset[split], label_value=0, num_samples=num_samples)
        label_1_samples = select_samples(dataset[split], label_value=1, num_samples=num_samples)
        balanced_dataset[split] = concatenate_datasets([label_0_samples, label_1_samples])
    return DatasetDict(balanced_dataset)

balanced_dataset = create_balanced_subset(dataset, num_samples=100)

print(balanced_dataset)


Filter:   0%|          | 0/1165 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1165 [00:00<?, ? examples/s]

Filter:   0%|          | 0/582 [00:00<?, ? examples/s]

Filter:   0%|          | 0/582 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['image_file_path', 'image', 'labels'],
        num_rows: 200
    })
    validation: Dataset({
        features: ['image_file_path', 'image', 'labels'],
        num_rows: 200
    })
    test: Dataset({
        features: ['image_file_path', 'image', 'labels'],
        num_rows: 200
    })
})


In [None]:
import torch
import torch.nn as nn
import accelerate
from torch.utils.data import DataLoader
from torchvision.transforms import (
    Compose, Resize, ToTensor, Normalize
)
from transformers import ViTModel,ViTConfig, ViTForImageClassification, Trainer, TrainingArguments
from datasets import DatasetDict
dataset=balanced_dataset
image_size = 224
transforms = Compose([
    Resize((image_size, image_size)),
    ToTensor(),
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

def preprocess_data(example):
    example['pixel_values'] = transforms(example['image'])
    return example

dataset = dataset.map(preprocess_data, batched=False)

dataset = dataset.remove_columns(["image_file_path", "image"])

print(dataset)

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['labels', 'pixel_values'],
        num_rows: 200
    })
    validation: Dataset({
        features: ['labels', 'pixel_values'],
        num_rows: 200
    })
    test: Dataset({
        features: ['labels', 'pixel_values'],
        num_rows: 200
    })
})


In [None]:
import torch
import torch.nn as nn
from transformers import ViTConfig, ViTForImageClassification
from transformers import Trainer, TrainingArguments
from datasets import DatasetDict


class ConvStem(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=7, stride=2, padding=3, output_size=None):
        super(ConvStem, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.output_size = output_size 

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.pool(x)
        if self.output_size:
            x = nn.functional.interpolate(x, size=self.output_size, mode='bilinear', align_corners=False)
        return x


class ViTWithConvStem(nn.Module):
    def __init__(self, config, conv_out_channels=64):
        super().__init__()
        patch_size = config.patch_size
        image_size = config.image_size
        aligned_size = (image_size // patch_size) * patch_size

        self.conv_stem = ConvStem(
            in_channels=config.num_channels,
            out_channels=conv_out_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            output_size=(aligned_size, aligned_size), 
        )
        self.vit = ViTForImageClassification(config)

    def forward(self, pixel_values, labels=None):
        x = self.conv_stem(pixel_values)

        batch_size, channels, height, width = x.shape
        patch_size = self.vit.config.patch_size

        x = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
        x = x.view(batch_size, -1, patch_size * patch_size * channels)

        outputs = self.vit(pixel_values=pixel_values, labels=labels)
        return outputs


config = ViTConfig(
    image_size=224,
    num_channels=3,
    patch_size=16,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    num_labels=2
)


model = ViTWithConvStem(config)

training_args = TrainingArguments(
    output_dir="./vit_model_1201",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    learning_rate=5e-5,
    weight_decay=0.01,
    save_total_limit=2,
    remove_unused_columns=False,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False
)

from evaluate import load
metric = load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = torch.argmax(torch.tensor(logits), dim=-1)
    return metric.compute(predictions=predictions.numpy(), references=labels)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=None,
    compute_metrics=compute_metrics,
)

trainer.train()
results = trainer.evaluate(dataset["test"])
print("Test Results:", results)


  trainer = Trainer(


  0%|          | 0/130 [00:00<?, ?it/s]

{'loss': 1.1706, 'grad_norm': 4.27631950378418, 'learning_rate': 4.615384615384616e-05, 'epoch': 0.77}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.7703848481178284, 'eval_accuracy': 0.5, 'eval_runtime': 28.7309, 'eval_samples_per_second': 6.961, 'eval_steps_per_second': 0.452, 'epoch': 1.0}
{'loss': 0.7787, 'grad_norm': 12.118109703063965, 'learning_rate': 4.230769230769231e-05, 'epoch': 1.54}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.6984092593193054, 'eval_accuracy': 0.57, 'eval_runtime': 28.6261, 'eval_samples_per_second': 6.987, 'eval_steps_per_second': 0.454, 'epoch': 2.0}
{'loss': 0.7181, 'grad_norm': 4.928044319152832, 'learning_rate': 3.846153846153846e-05, 'epoch': 2.31}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.659909725189209, 'eval_accuracy': 0.54, 'eval_runtime': 29.387, 'eval_samples_per_second': 6.806, 'eval_steps_per_second': 0.442, 'epoch': 3.0}
{'loss': 0.6674, 'grad_norm': 12.840006828308105, 'learning_rate': 3.461538461538462e-05, 'epoch': 3.08}
{'loss': 0.6399, 'grad_norm': 19.348966598510742, 'learning_rate': 3.0769230769230774e-05, 'epoch': 3.85}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.514397382736206, 'eval_accuracy': 0.765, 'eval_runtime': 30.4489, 'eval_samples_per_second': 6.568, 'eval_steps_per_second': 0.427, 'epoch': 4.0}
{'loss': 0.7666, 'grad_norm': 11.542110443115234, 'learning_rate': 2.6923076923076923e-05, 'epoch': 4.62}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.5182923674583435, 'eval_accuracy': 0.76, 'eval_runtime': 29.0403, 'eval_samples_per_second': 6.887, 'eval_steps_per_second': 0.448, 'epoch': 5.0}
{'loss': 0.4801, 'grad_norm': 15.616353988647461, 'learning_rate': 2.307692307692308e-05, 'epoch': 5.38}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.5544602274894714, 'eval_accuracy': 0.77, 'eval_runtime': 28.7717, 'eval_samples_per_second': 6.951, 'eval_steps_per_second': 0.452, 'epoch': 6.0}
{'loss': 0.3812, 'grad_norm': 19.119531631469727, 'learning_rate': 1.923076923076923e-05, 'epoch': 6.15}
{'loss': 0.2709, 'grad_norm': 4.7736124992370605, 'learning_rate': 1.5384615384615387e-05, 'epoch': 6.92}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.9060876369476318, 'eval_accuracy': 0.76, 'eval_runtime': 29.3644, 'eval_samples_per_second': 6.811, 'eval_steps_per_second': 0.443, 'epoch': 7.0}
{'loss': 0.2451, 'grad_norm': 11.232674598693848, 'learning_rate': 1.153846153846154e-05, 'epoch': 7.69}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.6265371441841125, 'eval_accuracy': 0.81, 'eval_runtime': 29.2492, 'eval_samples_per_second': 6.838, 'eval_steps_per_second': 0.444, 'epoch': 8.0}
{'loss': 0.3545, 'grad_norm': 2.2830491065979004, 'learning_rate': 7.692307692307694e-06, 'epoch': 8.46}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.6674125790596008, 'eval_accuracy': 0.78, 'eval_runtime': 29.1973, 'eval_samples_per_second': 6.85, 'eval_steps_per_second': 0.445, 'epoch': 9.0}
{'loss': 0.1964, 'grad_norm': 21.13983917236328, 'learning_rate': 3.846153846153847e-06, 'epoch': 9.23}
{'loss': 0.2211, 'grad_norm': 25.287439346313477, 'learning_rate': 0.0, 'epoch': 10.0}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.4938441514968872, 'eval_accuracy': 0.84, 'eval_runtime': 29.3959, 'eval_samples_per_second': 6.804, 'eval_steps_per_second': 0.442, 'epoch': 10.0}
{'train_runtime': 942.5098, 'train_samples_per_second': 2.122, 'train_steps_per_second': 0.138, 'train_loss': 0.5300321899927579, 'epoch': 10.0}


  0%|          | 0/13 [00:00<?, ?it/s]

Test Results: {'eval_loss': 0.716962993144989, 'eval_accuracy': 0.805, 'eval_runtime': 28.9942, 'eval_samples_per_second': 6.898, 'eval_steps_per_second': 0.448, 'epoch': 10.0}


In [8]:
from transformers import EarlyStoppingCallback

early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

training_args = TrainingArguments(
    output_dir="./vit_model_1201",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=15,
    learning_rate=4e-6,
    weight_decay=0.01,
    save_total_limit=2,
    remove_unused_columns=False,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False
)



trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=None,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

trainer.train()
results = trainer.evaluate(dataset["test"])
print("Test Results:", results)

  trainer = Trainer(


  0%|          | 0/195 [00:00<?, ?it/s]

{'loss': 0.2072, 'grad_norm': 3.23065447807312, 'learning_rate': 3.7948717948717945e-06, 'epoch': 0.77}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.4808088541030884, 'eval_accuracy': 0.84, 'eval_runtime': 29.1606, 'eval_samples_per_second': 6.859, 'eval_steps_per_second': 0.446, 'epoch': 1.0}
{'loss': 0.1822, 'grad_norm': 7.942155838012695, 'learning_rate': 3.5897435897435896e-06, 'epoch': 1.54}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.6042740345001221, 'eval_accuracy': 0.82, 'eval_runtime': 28.9353, 'eval_samples_per_second': 6.912, 'eval_steps_per_second': 0.449, 'epoch': 2.0}
{'loss': 0.119, 'grad_norm': 1.5368516445159912, 'learning_rate': 3.3846153846153843e-06, 'epoch': 2.31}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.6628224849700928, 'eval_accuracy': 0.82, 'eval_runtime': 28.1608, 'eval_samples_per_second': 7.102, 'eval_steps_per_second': 0.462, 'epoch': 3.0}
{'loss': 0.1951, 'grad_norm': 1.374727487564087, 'learning_rate': 3.179487179487179e-06, 'epoch': 3.08}
{'loss': 0.199, 'grad_norm': 20.743125915527344, 'learning_rate': 2.974358974358974e-06, 'epoch': 3.85}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.5558016300201416, 'eval_accuracy': 0.86, 'eval_runtime': 28.5721, 'eval_samples_per_second': 7.0, 'eval_steps_per_second': 0.455, 'epoch': 4.0}
{'loss': 0.1692, 'grad_norm': 15.532220840454102, 'learning_rate': 2.769230769230769e-06, 'epoch': 4.62}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.7302120327949524, 'eval_accuracy': 0.83, 'eval_runtime': 28.6722, 'eval_samples_per_second': 6.975, 'eval_steps_per_second': 0.453, 'epoch': 5.0}
{'loss': 0.162, 'grad_norm': 13.735671043395996, 'learning_rate': 2.5641025641025644e-06, 'epoch': 5.38}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.5948677062988281, 'eval_accuracy': 0.855, 'eval_runtime': 28.3314, 'eval_samples_per_second': 7.059, 'eval_steps_per_second': 0.459, 'epoch': 6.0}
{'loss': 0.1151, 'grad_norm': 14.335393905639648, 'learning_rate': 2.358974358974359e-06, 'epoch': 6.15}
{'loss': 0.1367, 'grad_norm': 2.18445086479187, 'learning_rate': 2.1538461538461538e-06, 'epoch': 6.92}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.6589303016662598, 'eval_accuracy': 0.845, 'eval_runtime': 29.1401, 'eval_samples_per_second': 6.863, 'eval_steps_per_second': 0.446, 'epoch': 7.0}
{'train_runtime': 656.003, 'train_samples_per_second': 4.573, 'train_steps_per_second': 0.297, 'train_loss': 0.1633449954768786, 'epoch': 7.0}


  0%|          | 0/13 [00:00<?, ?it/s]

Test Results: {'eval_loss': 0.8433399200439453, 'eval_accuracy': 0.815, 'eval_runtime': 28.5323, 'eval_samples_per_second': 7.01, 'eval_steps_per_second': 0.456, 'epoch': 7.0}


In [None]:
dataset = DatasetDict.load_from_disk("./chest_xray_clasnsificatio")

image_size = 224
transforms = Compose([
    Resize((image_size, image_size)),
    ToTensor(),
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

def preprocess_data(example):
    example['pixel_values'] = transforms(example['image'])
    return example

dataset = dataset.map(preprocess_data, batched=False)

dataset = dataset.remove_columns(["image_file_path", "image"])

print(dataset)

DatasetDict({
    train: Dataset({
        features: ['labels', 'pixel_values'],
        num_rows: 4077
    })
    validation: Dataset({
        features: ['labels', 'pixel_values'],
        num_rows: 1165
    })
    test: Dataset({
        features: ['labels', 'pixel_values'],
        num_rows: 582
    })
})


In [10]:

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=None,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

trainer.train()
results = trainer.evaluate(dataset["test"])
print("Test Results:", results)



  trainer = Trainer(


  0%|          | 0/3825 [00:00<?, ?it/s]

{'loss': 0.5214, 'grad_norm': 10.320158004760742, 'learning_rate': 3.989542483660131e-06, 'epoch': 0.04}
{'loss': 0.457, 'grad_norm': 13.489970207214355, 'learning_rate': 3.979084967320261e-06, 'epoch': 0.08}
{'loss': 0.3261, 'grad_norm': 5.10440731048584, 'learning_rate': 3.968627450980392e-06, 'epoch': 0.12}
{'loss': 0.2952, 'grad_norm': 4.566559791564941, 'learning_rate': 3.958169934640523e-06, 'epoch': 0.16}
{'loss': 0.4452, 'grad_norm': 8.832679748535156, 'learning_rate': 3.947712418300654e-06, 'epoch': 0.2}
{'loss': 0.2808, 'grad_norm': 14.268045425415039, 'learning_rate': 3.937254901960784e-06, 'epoch': 0.24}
{'loss': 0.3413, 'grad_norm': 3.437281608581543, 'learning_rate': 3.9267973856209145e-06, 'epoch': 0.27}
{'loss': 0.3436, 'grad_norm': 3.749412775039673, 'learning_rate': 3.916339869281045e-06, 'epoch': 0.31}
{'loss': 0.3351, 'grad_norm': 4.614947319030762, 'learning_rate': 3.905882352941176e-06, 'epoch': 0.35}
{'loss': 0.3635, 'grad_norm': 6.122328758239746, 'learning_rate

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.2900601327419281, 'eval_accuracy': 0.8755364806866953, 'eval_runtime': 166.9877, 'eval_samples_per_second': 6.977, 'eval_steps_per_second': 0.437, 'epoch': 1.0}
{'loss': 0.2952, 'grad_norm': 10.296650886535645, 'learning_rate': 3.7281045751633987e-06, 'epoch': 1.02}
{'loss': 0.2965, 'grad_norm': 2.2865076065063477, 'learning_rate': 3.7176470588235295e-06, 'epoch': 1.06}
{'loss': 0.4124, 'grad_norm': 24.955238342285156, 'learning_rate': 3.7071895424836603e-06, 'epoch': 1.1}
{'loss': 0.2849, 'grad_norm': 6.959079265594482, 'learning_rate': 3.696732026143791e-06, 'epoch': 1.14}
{'loss': 0.2876, 'grad_norm': 11.02238655090332, 'learning_rate': 3.686274509803921e-06, 'epoch': 1.18}
{'loss': 0.3225, 'grad_norm': 10.804498672485352, 'learning_rate': 3.675816993464052e-06, 'epoch': 1.22}
{'loss': 0.3173, 'grad_norm': 13.52967643737793, 'learning_rate': 3.6653594771241826e-06, 'epoch': 1.25}
{'loss': 0.2775, 'grad_norm': 16.68783187866211, 'learning_rate': 3.6549019607843134e-06

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.2475915402173996, 'eval_accuracy': 0.8969957081545065, 'eval_runtime': 168.3556, 'eval_samples_per_second': 6.92, 'eval_steps_per_second': 0.434, 'epoch': 2.0}
{'loss': 0.2754, 'grad_norm': 4.914225101470947, 'learning_rate': 3.4562091503267976e-06, 'epoch': 2.04}
{'loss': 0.2733, 'grad_norm': 5.5080437660217285, 'learning_rate': 3.445751633986928e-06, 'epoch': 2.08}
{'loss': 0.2625, 'grad_norm': 18.395261764526367, 'learning_rate': 3.4352941176470583e-06, 'epoch': 2.12}
{'loss': 0.3188, 'grad_norm': 27.508745193481445, 'learning_rate': 3.424836601307189e-06, 'epoch': 2.16}
{'loss': 0.2197, 'grad_norm': 4.140567302703857, 'learning_rate': 3.41437908496732e-06, 'epoch': 2.2}
{'loss': 0.3062, 'grad_norm': 8.938861846923828, 'learning_rate': 3.4039215686274507e-06, 'epoch': 2.24}
{'loss': 0.1869, 'grad_norm': 3.8084447383880615, 'learning_rate': 3.3934640522875815e-06, 'epoch': 2.27}
{'loss': 0.2153, 'grad_norm': 10.288230895996094, 'learning_rate': 3.3830065359477123e-06,

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.22645214200019836, 'eval_accuracy': 0.911587982832618, 'eval_runtime': 169.0437, 'eval_samples_per_second': 6.892, 'eval_steps_per_second': 0.432, 'epoch': 3.0}
{'loss': 0.1329, 'grad_norm': 2.8249239921569824, 'learning_rate': 3.1947712418300653e-06, 'epoch': 3.02}
{'loss': 0.285, 'grad_norm': 11.341493606567383, 'learning_rate': 3.1843137254901956e-06, 'epoch': 3.06}
{'loss': 0.2095, 'grad_norm': 5.133615493774414, 'learning_rate': 3.1738562091503264e-06, 'epoch': 3.1}
{'loss': 0.354, 'grad_norm': 19.815032958984375, 'learning_rate': 3.163398692810457e-06, 'epoch': 3.14}
{'loss': 0.2497, 'grad_norm': 16.99631118774414, 'learning_rate': 3.152941176470588e-06, 'epoch': 3.18}
{'loss': 0.2195, 'grad_norm': 7.619422435760498, 'learning_rate': 3.1424836601307188e-06, 'epoch': 3.22}
{'loss': 0.2157, 'grad_norm': 6.248483180999756, 'learning_rate': 3.1320261437908496e-06, 'epoch': 3.25}
{'loss': 0.2653, 'grad_norm': 3.9977824687957764, 'learning_rate': 3.1215686274509804e-06,

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.22256095707416534, 'eval_accuracy': 0.9064377682403434, 'eval_runtime': 170.0531, 'eval_samples_per_second': 6.851, 'eval_steps_per_second': 0.429, 'epoch': 4.0}
{'loss': 0.225, 'grad_norm': 21.615137100219727, 'learning_rate': 2.9228758169934637e-06, 'epoch': 4.04}
{'loss': 0.2013, 'grad_norm': 11.737386703491211, 'learning_rate': 2.9124183006535945e-06, 'epoch': 4.08}
{'loss': 0.2494, 'grad_norm': 8.04926586151123, 'learning_rate': 2.9019607843137253e-06, 'epoch': 4.12}
{'loss': 0.1915, 'grad_norm': 13.797473907470703, 'learning_rate': 2.891503267973856e-06, 'epoch': 4.16}
{'loss': 0.1942, 'grad_norm': 3.067655086517334, 'learning_rate': 2.881045751633987e-06, 'epoch': 4.2}
{'loss': 0.2055, 'grad_norm': 23.865129470825195, 'learning_rate': 2.8705882352941177e-06, 'epoch': 4.24}
{'loss': 0.1772, 'grad_norm': 8.667461395263672, 'learning_rate': 2.8601307189542485e-06, 'epoch': 4.27}
{'loss': 0.2389, 'grad_norm': 9.793421745300293, 'learning_rate': 2.8496732026143792e-06

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.2013339102268219, 'eval_accuracy': 0.9201716738197425, 'eval_runtime': 173.7223, 'eval_samples_per_second': 6.706, 'eval_steps_per_second': 0.42, 'epoch': 5.0}
{'loss': 0.1526, 'grad_norm': 24.68008041381836, 'learning_rate': 2.661437908496732e-06, 'epoch': 5.02}
{'loss': 0.2785, 'grad_norm': 11.892131805419922, 'learning_rate': 2.6509803921568626e-06, 'epoch': 5.06}
{'loss': 0.153, 'grad_norm': 10.179633140563965, 'learning_rate': 2.6405228758169934e-06, 'epoch': 5.1}
{'loss': 0.2072, 'grad_norm': 10.348719596862793, 'learning_rate': 2.630065359477124e-06, 'epoch': 5.14}
{'loss': 0.2089, 'grad_norm': 17.510812759399414, 'learning_rate': 2.619607843137255e-06, 'epoch': 5.18}
{'loss': 0.2862, 'grad_norm': 32.28323745727539, 'learning_rate': 2.6091503267973858e-06, 'epoch': 5.22}
{'loss': 0.2473, 'grad_norm': 14.37722110748291, 'learning_rate': 2.598692810457516e-06, 'epoch': 5.25}
{'loss': 0.1857, 'grad_norm': 5.4232378005981445, 'learning_rate': 2.588235294117647e-06, '

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.1890980452299118, 'eval_accuracy': 0.927038626609442, 'eval_runtime': 173.0074, 'eval_samples_per_second': 6.734, 'eval_steps_per_second': 0.422, 'epoch': 6.0}
{'loss': 0.1905, 'grad_norm': 9.21054458618164, 'learning_rate': 2.3895424836601307e-06, 'epoch': 6.04}
{'loss': 0.1918, 'grad_norm': 5.767477512359619, 'learning_rate': 2.3790849673202615e-06, 'epoch': 6.08}
{'loss': 0.1513, 'grad_norm': 29.399259567260742, 'learning_rate': 2.3686274509803923e-06, 'epoch': 6.12}
{'loss': 0.0793, 'grad_norm': 9.706868171691895, 'learning_rate': 2.3581699346405226e-06, 'epoch': 6.16}
{'loss': 0.1973, 'grad_norm': 13.061738014221191, 'learning_rate': 2.3477124183006534e-06, 'epoch': 6.2}
{'loss': 0.1188, 'grad_norm': 19.511974334716797, 'learning_rate': 2.3372549019607842e-06, 'epoch': 6.24}
{'loss': 0.1601, 'grad_norm': 2.9338784217834473, 'learning_rate': 2.326797385620915e-06, 'epoch': 6.27}
{'loss': 0.2011, 'grad_norm': 110.47612762451172, 'learning_rate': 2.316339869281046e-06

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.20688769221305847, 'eval_accuracy': 0.9167381974248927, 'eval_runtime': 173.9126, 'eval_samples_per_second': 6.699, 'eval_steps_per_second': 0.42, 'epoch': 7.0}
{'loss': 0.3348, 'grad_norm': 45.089256286621094, 'learning_rate': 2.1281045751633988e-06, 'epoch': 7.02}
{'loss': 0.2477, 'grad_norm': 30.357824325561523, 'learning_rate': 2.117647058823529e-06, 'epoch': 7.06}
{'loss': 0.2128, 'grad_norm': 22.115581512451172, 'learning_rate': 2.10718954248366e-06, 'epoch': 7.1}
{'loss': 0.168, 'grad_norm': 42.2343864440918, 'learning_rate': 2.0967320261437907e-06, 'epoch': 7.14}
{'loss': 0.1488, 'grad_norm': 15.824294090270996, 'learning_rate': 2.0862745098039215e-06, 'epoch': 7.18}
{'loss': 0.1761, 'grad_norm': 16.50252342224121, 'learning_rate': 2.0758169934640523e-06, 'epoch': 7.22}
{'loss': 0.2273, 'grad_norm': 25.550832748413086, 'learning_rate': 2.065359477124183e-06, 'epoch': 7.25}
{'loss': 0.1694, 'grad_norm': 9.370136260986328, 'learning_rate': 2.0549019607843135e-06, 

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.18939606845378876, 'eval_accuracy': 0.934763948497854, 'eval_runtime': 176.4509, 'eval_samples_per_second': 6.602, 'eval_steps_per_second': 0.414, 'epoch': 8.0}
{'loss': 0.24, 'grad_norm': 5.828797340393066, 'learning_rate': 1.8562091503267974e-06, 'epoch': 8.04}
{'loss': 0.1517, 'grad_norm': 8.457097053527832, 'learning_rate': 1.845751633986928e-06, 'epoch': 8.08}
{'loss': 0.0856, 'grad_norm': 1.0516471862792969, 'learning_rate': 1.8352941176470586e-06, 'epoch': 8.12}
{'loss': 0.1758, 'grad_norm': 12.280799865722656, 'learning_rate': 1.8248366013071894e-06, 'epoch': 8.16}
{'loss': 0.0824, 'grad_norm': 3.064483404159546, 'learning_rate': 1.8143790849673202e-06, 'epoch': 8.2}
{'loss': 0.1372, 'grad_norm': 17.69049644470215, 'learning_rate': 1.803921568627451e-06, 'epoch': 8.24}
{'loss': 0.2963, 'grad_norm': 10.441651344299316, 'learning_rate': 1.7934640522875818e-06, 'epoch': 8.27}
{'loss': 0.1867, 'grad_norm': 43.7600212097168, 'learning_rate': 1.7830065359477123e-06, '

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.19551870226860046, 'eval_accuracy': 0.9313304721030042, 'eval_runtime': 178.3785, 'eval_samples_per_second': 6.531, 'eval_steps_per_second': 0.409, 'epoch': 9.0}
{'loss': 0.2643, 'grad_norm': 0.5015349984169006, 'learning_rate': 1.5947712418300653e-06, 'epoch': 9.02}
{'loss': 0.1287, 'grad_norm': 27.827131271362305, 'learning_rate': 1.584313725490196e-06, 'epoch': 9.06}
{'loss': 0.2494, 'grad_norm': 27.059717178344727, 'learning_rate': 1.5738562091503267e-06, 'epoch': 9.1}
{'loss': 0.1579, 'grad_norm': 63.75535583496094, 'learning_rate': 1.5633986928104575e-06, 'epoch': 9.14}
{'loss': 0.0746, 'grad_norm': 11.734855651855469, 'learning_rate': 1.5529411764705883e-06, 'epoch': 9.18}
{'loss': 0.1533, 'grad_norm': 33.630393981933594, 'learning_rate': 1.5424836601307189e-06, 'epoch': 9.22}
{'loss': 0.1702, 'grad_norm': 16.898588180541992, 'learning_rate': 1.5320261437908496e-06, 'epoch': 9.25}
{'loss': 0.155, 'grad_norm': 10.35288143157959, 'learning_rate': 1.5215686274509802

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.201002299785614, 'eval_accuracy': 0.936480686695279, 'eval_runtime': 179.2001, 'eval_samples_per_second': 6.501, 'eval_steps_per_second': 0.407, 'epoch': 10.0}
{'loss': 0.1635, 'grad_norm': 70.19966125488281, 'learning_rate': 1.322875816993464e-06, 'epoch': 10.04}
{'loss': 0.3613, 'grad_norm': 35.549468994140625, 'learning_rate': 1.3124183006535948e-06, 'epoch': 10.08}
{'loss': 0.2414, 'grad_norm': 2.1722872257232666, 'learning_rate': 1.3019607843137254e-06, 'epoch': 10.12}
{'loss': 0.091, 'grad_norm': 16.084609985351562, 'learning_rate': 1.2915032679738562e-06, 'epoch': 10.16}
{'loss': 0.0604, 'grad_norm': 8.946029663085938, 'learning_rate': 1.2810457516339867e-06, 'epoch': 10.2}
{'loss': 0.2004, 'grad_norm': 92.22409057617188, 'learning_rate': 1.2705882352941175e-06, 'epoch': 10.24}
{'loss': 0.085, 'grad_norm': 1.236768364906311, 'learning_rate': 1.2601307189542483e-06, 'epoch': 10.27}
{'loss': 0.2273, 'grad_norm': 17.48619270324707, 'learning_rate': 1.249673202614379

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.2166920155286789, 'eval_accuracy': 0.9339055793991416, 'eval_runtime': 179.5407, 'eval_samples_per_second': 6.489, 'eval_steps_per_second': 0.407, 'epoch': 11.0}
{'loss': 0.0839, 'grad_norm': 16.2960205078125, 'learning_rate': 1.061437908496732e-06, 'epoch': 11.02}
{'loss': 0.0575, 'grad_norm': 11.0881929397583, 'learning_rate': 1.0509803921568627e-06, 'epoch': 11.06}
{'loss': 0.1321, 'grad_norm': 92.45496368408203, 'learning_rate': 1.0405228758169935e-06, 'epoch': 11.1}
{'loss': 0.1811, 'grad_norm': 33.560585021972656, 'learning_rate': 1.030065359477124e-06, 'epoch': 11.14}
{'loss': 0.1623, 'grad_norm': 47.79589080810547, 'learning_rate': 1.0196078431372548e-06, 'epoch': 11.18}
{'loss': 0.1927, 'grad_norm': 21.562644958496094, 'learning_rate': 1.0091503267973856e-06, 'epoch': 11.22}
{'loss': 0.1292, 'grad_norm': 20.521638870239258, 'learning_rate': 9.986928104575162e-07, 'epoch': 11.25}
{'loss': 0.2115, 'grad_norm': 11.966313362121582, 'learning_rate': 9.88235294117647

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.21931400895118713, 'eval_accuracy': 0.936480686695279, 'eval_runtime': 179.1485, 'eval_samples_per_second': 6.503, 'eval_steps_per_second': 0.407, 'epoch': 12.0}
{'loss': 0.1633, 'grad_norm': 12.932236671447754, 'learning_rate': 7.895424836601306e-07, 'epoch': 12.04}
{'loss': 0.1787, 'grad_norm': 33.9129524230957, 'learning_rate': 7.790849673202614e-07, 'epoch': 12.08}
{'loss': 0.0743, 'grad_norm': 40.722164154052734, 'learning_rate': 7.686274509803921e-07, 'epoch': 12.12}
{'loss': 0.1596, 'grad_norm': 26.881332397460938, 'learning_rate': 7.581699346405228e-07, 'epoch': 12.16}
{'loss': 0.1615, 'grad_norm': 0.3181571960449219, 'learning_rate': 7.477124183006536e-07, 'epoch': 12.2}
{'loss': 0.1485, 'grad_norm': 0.9654801487922668, 'learning_rate': 7.372549019607843e-07, 'epoch': 12.24}
{'loss': 0.2435, 'grad_norm': 52.502227783203125, 'learning_rate': 7.26797385620915e-07, 'epoch': 12.27}
{'loss': 0.1498, 'grad_norm': 36.79524230957031, 'learning_rate': 7.163398692810458e

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.21930508315563202, 'eval_accuracy': 0.9407725321888412, 'eval_runtime': 179.8599, 'eval_samples_per_second': 6.477, 'eval_steps_per_second': 0.406, 'epoch': 13.0}
{'loss': 0.2113, 'grad_norm': 1.5805141925811768, 'learning_rate': 5.281045751633986e-07, 'epoch': 13.02}
{'loss': 0.1123, 'grad_norm': 10.298863410949707, 'learning_rate': 5.176470588235294e-07, 'epoch': 13.06}
{'loss': 0.1216, 'grad_norm': 16.241456985473633, 'learning_rate': 5.071895424836601e-07, 'epoch': 13.1}
{'loss': 0.1847, 'grad_norm': 24.810436248779297, 'learning_rate': 4.967320261437908e-07, 'epoch': 13.14}
{'loss': 0.0873, 'grad_norm': 19.310667037963867, 'learning_rate': 4.862745098039216e-07, 'epoch': 13.18}
{'loss': 0.1235, 'grad_norm': 18.36367416381836, 'learning_rate': 4.758169934640522e-07, 'epoch': 13.22}
{'loss': 0.1489, 'grad_norm': 83.51899719238281, 'learning_rate': 4.65359477124183e-07, 'epoch': 13.25}
{'loss': 0.2251, 'grad_norm': 32.3670654296875, 'learning_rate': 4.549019607843137e

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.23749525845050812, 'eval_accuracy': 0.9356223175965666, 'eval_runtime': 181.017, 'eval_samples_per_second': 6.436, 'eval_steps_per_second': 0.403, 'epoch': 14.0}
{'loss': 0.1755, 'grad_norm': 13.70283317565918, 'learning_rate': 2.5620915032679736e-07, 'epoch': 14.04}
{'loss': 0.1464, 'grad_norm': 47.662940979003906, 'learning_rate': 2.457516339869281e-07, 'epoch': 14.08}
{'loss': 0.0998, 'grad_norm': 3.9633285999298096, 'learning_rate': 2.352941176470588e-07, 'epoch': 14.12}
{'loss': 0.1944, 'grad_norm': 41.443214416503906, 'learning_rate': 2.2483660130718955e-07, 'epoch': 14.16}
{'loss': 0.1429, 'grad_norm': 85.75756072998047, 'learning_rate': 2.1437908496732023e-07, 'epoch': 14.2}
{'loss': 0.159, 'grad_norm': 3.4543466567993164, 'learning_rate': 2.0392156862745097e-07, 'epoch': 14.24}
{'loss': 0.2513, 'grad_norm': 4.192806243896484, 'learning_rate': 1.934640522875817e-07, 'epoch': 14.27}
{'loss': 0.1592, 'grad_norm': 101.54142761230469, 'learning_rate': 1.830065359477

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.2497621476650238, 'eval_accuracy': 0.9339055793991416, 'eval_runtime': 169.7735, 'eval_samples_per_second': 6.862, 'eval_steps_per_second': 0.43, 'epoch': 15.0}
{'train_runtime': 22380.228, 'train_samples_per_second': 2.733, 'train_steps_per_second': 0.171, 'train_loss': 0.20743973994371936, 'epoch': 15.0}


  0%|          | 0/37 [00:00<?, ?it/s]

Test Results: {'eval_loss': 0.20613574981689453, 'eval_accuracy': 0.936426116838488, 'eval_runtime': 86.1279, 'eval_samples_per_second': 6.757, 'eval_steps_per_second': 0.43, 'epoch': 15.0}


In [11]:
torch.save(model.state_dict(), 'model_parameters_1201.pth')

In [12]:
training_args = TrainingArguments(
    output_dir="./vit_model_1201",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=20,
    learning_rate=5e-9,
    weight_decay=0.01,
    save_total_limit=2,
    remove_unused_columns=False,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=None,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

trainer.train()
results = trainer.evaluate(dataset["test"])
print("Test Results:", results)

  trainer = Trainer(


  0%|          | 0/5100 [00:00<?, ?it/s]

{'loss': 0.0962, 'grad_norm': 0.6237013339996338, 'learning_rate': 4.990196078431373e-09, 'epoch': 0.04}
{'loss': 0.0526, 'grad_norm': 1.2592545747756958, 'learning_rate': 4.980392156862745e-09, 'epoch': 0.08}
{'loss': 0.1246, 'grad_norm': 12.886646270751953, 'learning_rate': 4.970588235294118e-09, 'epoch': 0.12}
{'loss': 0.1764, 'grad_norm': 1.2968604564666748, 'learning_rate': 4.96078431372549e-09, 'epoch': 0.16}
{'loss': 0.1424, 'grad_norm': 20.541128158569336, 'learning_rate': 4.950980392156863e-09, 'epoch': 0.2}
{'loss': 0.0992, 'grad_norm': 22.400251388549805, 'learning_rate': 4.941176470588235e-09, 'epoch': 0.24}
{'loss': 0.1788, 'grad_norm': 6.979392051696777, 'learning_rate': 4.931372549019608e-09, 'epoch': 0.27}
{'loss': 0.1596, 'grad_norm': 13.072405815124512, 'learning_rate': 4.92156862745098e-09, 'epoch': 0.31}
{'loss': 0.0989, 'grad_norm': 3.2910754680633545, 'learning_rate': 4.9117647058823525e-09, 'epoch': 0.35}
{'loss': 0.1727, 'grad_norm': 70.84709930419922, 'learning

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.2209848314523697, 'eval_accuracy': 0.9381974248927039, 'eval_runtime': 174.8566, 'eval_samples_per_second': 6.663, 'eval_steps_per_second': 0.417, 'epoch': 1.0}
{'loss': 0.159, 'grad_norm': 50.63440704345703, 'learning_rate': 4.745098039215686e-09, 'epoch': 1.02}
{'loss': 0.0831, 'grad_norm': 2.2748358249664307, 'learning_rate': 4.735294117647059e-09, 'epoch': 1.06}
{'loss': 0.1749, 'grad_norm': 82.53803253173828, 'learning_rate': 4.725490196078431e-09, 'epoch': 1.1}
{'loss': 0.1076, 'grad_norm': 23.359214782714844, 'learning_rate': 4.715686274509804e-09, 'epoch': 1.14}
{'loss': 0.1487, 'grad_norm': 4.227215766906738, 'learning_rate': 4.705882352941176e-09, 'epoch': 1.18}
{'loss': 0.1612, 'grad_norm': 41.28520202636719, 'learning_rate': 4.696078431372549e-09, 'epoch': 1.22}
{'loss': 0.1064, 'grad_norm': 28.58900260925293, 'learning_rate': 4.686274509803922e-09, 'epoch': 1.25}
{'loss': 0.1697, 'grad_norm': 4.768855571746826, 'learning_rate': 4.676470588235294e-09, 'epoch

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.2229243665933609, 'eval_accuracy': 0.936480686695279, 'eval_runtime': 170.6073, 'eval_samples_per_second': 6.829, 'eval_steps_per_second': 0.428, 'epoch': 2.0}
{'loss': 0.1077, 'grad_norm': 0.7562034130096436, 'learning_rate': 4.490196078431373e-09, 'epoch': 2.04}
{'loss': 0.1432, 'grad_norm': 25.7589111328125, 'learning_rate': 4.480392156862746e-09, 'epoch': 2.08}
{'loss': 0.1019, 'grad_norm': 15.096132278442383, 'learning_rate': 4.470588235294118e-09, 'epoch': 2.12}
{'loss': 0.1917, 'grad_norm': 29.30374526977539, 'learning_rate': 4.460784313725491e-09, 'epoch': 2.16}
{'loss': 0.0972, 'grad_norm': 6.981141090393066, 'learning_rate': 4.450980392156863e-09, 'epoch': 2.2}
{'loss': 0.1724, 'grad_norm': 19.163679122924805, 'learning_rate': 4.441176470588235e-09, 'epoch': 2.24}
{'loss': 0.091, 'grad_norm': 19.702045440673828, 'learning_rate': 4.431372549019608e-09, 'epoch': 2.27}
{'loss': 0.1063, 'grad_norm': 21.878149032592773, 'learning_rate': 4.42156862745098e-09, 'epoch

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.2239675670862198, 'eval_accuracy': 0.9356223175965666, 'eval_runtime': 172.7223, 'eval_samples_per_second': 6.745, 'eval_steps_per_second': 0.423, 'epoch': 3.0}
{'loss': 0.0867, 'grad_norm': 22.529979705810547, 'learning_rate': 4.2450980392156865e-09, 'epoch': 3.02}
{'loss': 0.1713, 'grad_norm': 11.340468406677246, 'learning_rate': 4.235294117647059e-09, 'epoch': 3.06}
{'loss': 0.0787, 'grad_norm': 15.864761352539062, 'learning_rate': 4.2254901960784315e-09, 'epoch': 3.1}
{'loss': 0.1989, 'grad_norm': 47.87004470825195, 'learning_rate': 4.215686274509804e-09, 'epoch': 3.14}
{'loss': 0.1534, 'grad_norm': 58.8902473449707, 'learning_rate': 4.205882352941177e-09, 'epoch': 3.18}
{'loss': 0.1157, 'grad_norm': 30.273406982421875, 'learning_rate': 4.1960784313725496e-09, 'epoch': 3.22}
{'loss': 0.1348, 'grad_norm': 23.036636352539062, 'learning_rate': 4.186274509803922e-09, 'epoch': 3.25}
{'loss': 0.1351, 'grad_norm': 0.27133795619010925, 'learning_rate': 4.176470588235295e-09

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.2247777283191681, 'eval_accuracy': 0.9356223175965666, 'eval_runtime': 174.3097, 'eval_samples_per_second': 6.684, 'eval_steps_per_second': 0.419, 'epoch': 4.0}
{'train_runtime': 6085.9954, 'train_samples_per_second': 13.398, 'train_steps_per_second': 0.838, 'train_loss': 0.1472910403694008, 'epoch': 4.0}


  0%|          | 0/37 [00:00<?, ?it/s]

Test Results: {'eval_loss': 0.19924920797348022, 'eval_accuracy': 0.9415807560137457, 'eval_runtime': 88.2956, 'eval_samples_per_second': 6.591, 'eval_steps_per_second': 0.419, 'epoch': 4.0}


In [13]:
torch.save(model.state_dict(), 'model_parameters_1202.pth')