In [None]:
from datasets import load_dataset
# Fetch the "full" configuration of the chest X-ray classification dataset from the
# Hugging Face Hub and write a local snapshot that the training cell reloads from disk.
# NOTE(review): the directory name looks like a typo of "chest_xray_classification";
# it is kept unchanged because the training cell reloads from this exact path.
dataset = load_dataset(path="keremberke/chest-xray-classification", name="full")
dataset.save_to_disk("./chest_xray_clasnsificatio")

In [None]:
import torch
import accelerate
from torch.utils.data import DataLoader
from torchvision.transforms import (
    Compose, Resize, ToTensor, Normalize
)
from transformers import ViTConfig, ViTForImageClassification, Trainer, TrainingArguments
from datasets import DatasetDict

# Reload the local snapshot written by the download cell.
DATASET_DIR = "./chest_xray_clasnsificatio"
dataset = DatasetDict.load_from_disk(DATASET_DIR)

# Square input resolution; the ViTConfig below is built from the same value.
image_size = 224

# Per-image preprocessing pipeline:
#   1. resize to image_size x image_size,
#   2. convert the PIL image to a float tensor (ToTensor scales 8-bit pixels to [0, 1]),
#   3. normalize each of the three channels with mean 0.5 / std 0.5, mapping [0, 1] -> [-1, 1].
transforms = Compose(
    [
        Resize(size=(image_size, image_size)),
        ToTensor(),
        Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Function to preprocess data by applying transformations to images
def preprocess_data(example):
    """Add a normalized ``pixel_values`` tensor to a single dataset example.

    Args:
        example: one dataset row; ``example['image']`` is presumably a PIL image
            (the decoded ``Image`` feature from ``datasets``) — TODO confirm.

    Returns:
        The same ``example`` mapping with ``pixel_values`` set to the output of
        the module-level ``transforms`` pipeline.
    """
    # The model is configured with num_channels=3 and Normalize uses 3-channel
    # statistics, so force RGB: a grayscale ("L") or RGBA image would otherwise
    # yield a 1- or 4-channel tensor and fail in Normalize / the patch embedding.
    example['pixel_values'] = transforms(example['image'].convert("RGB"))
    return example

# Run preprocess_data over every split one example at a time, then drop the raw
# file path and decoded image so only the model inputs and labels remain.
dataset = (
    dataset
    .map(preprocess_data, batched=False)
    .remove_columns(["image_file_path", "image"])
)

# Seed torch before building the model. ViTForImageClassification(config) draws
# its initial weights randomly, and the Trainer only seeds when it is constructed
# (after this cell). Without this seed, the starting weights differ on every run.
SEED = 42
torch.manual_seed(SEED)

# ViT-Base/16-sized configuration (12 layers, 12 heads, hidden size 768, MLP 3072),
# for 3-channel image_size x image_size inputs and a 2-way classification head.
config = ViTConfig(
    image_size=image_size,
    num_channels=3,
    patch_size=16,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    num_labels=2
)

# Instantiate the classifier from the configuration (randomly initialized, not
# loaded from pretrained weights).
model = ViTForImageClassification(config)

# Training configuration: evaluate and checkpoint once per epoch, keep at most two
# checkpoints, and restore the checkpoint with the best validation accuracy at the end.
training_args = TrainingArguments(
    output_dir="./vit_model",
    # `eval_strategy` replaces the deprecated `evaluation_strategy` keyword. It must
    # match save_strategy for load_best_model_at_end to work.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    learning_rate=5e-5,
    weight_decay=0.01,
    save_total_limit=2,
    # Keep every dataset column when batching, instead of dropping columns the
    # model's forward() does not accept.
    remove_unused_columns=False,
    load_best_model_at_end=True,
    # Key returned by compute_metrics; the Trainer tracks it as "eval_accuracy".
    metric_for_best_model="accuracy",
    # Disable automatic logging integrations. With the default, the run tried to log
    # in to wandb and blocked on an interactive API-key prompt, which breaks
    # Restart & Run All. Set to "wandb" explicitly if experiment tracking is wanted.
    report_to="none",
    push_to_hub=False
)

from evaluate import load

# Accuracy metric from the Hugging Face `evaluate` library.
metric = load("accuracy")


def compute_metrics(eval_pred):
    """Compute classification accuracy from the Trainer's evaluation outputs.

    Args:
        eval_pred: (predictions, label_ids) pair produced by the Trainer.
            ``predictions`` holds the classifier logits, presumably shaped
            (num_examples, num_labels) — TODO confirm if the model config changes.

    Returns:
        dict: the result of ``metric.compute``, i.e. {"accuracy": value}.
    """
    logits, labels = eval_pred
    # Models that also return hidden states / attentions make the Trainer pass a
    # tuple of arrays; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = torch.argmax(torch.as_tensor(logits), dim=-1)
    return metric.compute(predictions=predictions.numpy(), references=labels)

# Wire model, data splits and metrics into a Trainer. No `tokenizer` is passed:
# the argument is deprecated, and passing None was the default anyway.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics
)

# Train. Validation runs every epoch, per training_args.
trainer.train()

# Final evaluation on the held-out test split. metric_key_prefix="test" labels
# the results test_loss / test_accuracy, so they are not confused with the
# per-epoch validation metrics (which use the default "eval" prefix).
results = trainer.evaluate(eval_dataset=dataset["test"], metric_key_prefix="test")
print("Test Results:", results)


Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

  trainer = Trainer(
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\19811\_netrc


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777286247, max=1.0…

  0%|          | 0/2550 [00:00<?, ?it/s]

{'loss': 1.2936, 'grad_norm': 23.22121238708496, 'learning_rate': 4.980392156862745e-05, 'epoch': 0.04}
{'loss': 0.7692, 'grad_norm': 0.4044587314128876, 'learning_rate': 4.960784313725491e-05, 'epoch': 0.08}
{'loss': 0.5978, 'grad_norm': 2.999021053314209, 'learning_rate': 4.9411764705882355e-05, 'epoch': 0.12}
{'loss': 0.5879, 'grad_norm': 2.198652982711792, 'learning_rate': 4.9215686274509804e-05, 'epoch': 0.16}
{'loss': 0.6272, 'grad_norm': 1.4779688119888306, 'learning_rate': 4.901960784313725e-05, 'epoch': 0.2}
{'loss': 0.5429, 'grad_norm': 5.54686164855957, 'learning_rate': 4.882352941176471e-05, 'epoch': 0.24}
{'loss': 0.5327, 'grad_norm': 7.827354431152344, 'learning_rate': 4.862745098039216e-05, 'epoch': 0.27}
{'loss': 0.6011, 'grad_norm': 8.274473190307617, 'learning_rate': 4.843137254901961e-05, 'epoch': 0.31}
{'loss': 0.6586, 'grad_norm': 9.161334037780762, 'learning_rate': 4.823529411764706e-05, 'epoch': 0.35}
{'loss': 0.8176, 'grad_norm': 5.5970377922058105, 'learning_ra

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.2792075276374817, 'eval_accuracy': 0.878969957081545, 'eval_runtime': 233.7612, 'eval_samples_per_second': 4.984, 'eval_steps_per_second': 0.312, 'epoch': 1.0}
{'loss': 0.2869, 'grad_norm': 15.761570930480957, 'learning_rate': 4.490196078431373e-05, 'epoch': 1.02}
{'loss': 0.4262, 'grad_norm': 15.251462936401367, 'learning_rate': 4.470588235294118e-05, 'epoch': 1.06}
{'loss': 0.4244, 'grad_norm': 14.319478034973145, 'learning_rate': 4.450980392156863e-05, 'epoch': 1.1}
{'loss': 0.3106, 'grad_norm': 8.241303443908691, 'learning_rate': 4.431372549019608e-05, 'epoch': 1.14}
{'loss': 0.3482, 'grad_norm': 11.514594078063965, 'learning_rate': 4.411764705882353e-05, 'epoch': 1.18}
{'loss': 0.5182, 'grad_norm': 15.224875450134277, 'learning_rate': 4.392156862745098e-05, 'epoch': 1.22}
{'loss': 0.3725, 'grad_norm': 1.137607455253601, 'learning_rate': 4.3725490196078437e-05, 'epoch': 1.25}
{'loss': 0.279, 'grad_norm': 3.010125160217285, 'learning_rate': 4.3529411764705885e-05, 'e

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.2987922728061676, 'eval_accuracy': 0.9012875536480687, 'eval_runtime': 147.1938, 'eval_samples_per_second': 7.915, 'eval_steps_per_second': 0.496, 'epoch': 2.0}
{'loss': 0.3358, 'grad_norm': 13.516194343566895, 'learning_rate': 3.980392156862745e-05, 'epoch': 2.04}
{'loss': 0.3068, 'grad_norm': 10.614214897155762, 'learning_rate': 3.96078431372549e-05, 'epoch': 2.08}
{'loss': 0.2537, 'grad_norm': 16.807022094726562, 'learning_rate': 3.9411764705882356e-05, 'epoch': 2.12}
{'loss': 0.4174, 'grad_norm': 23.811050415039062, 'learning_rate': 3.9215686274509805e-05, 'epoch': 2.16}
{'loss': 0.3191, 'grad_norm': 2.6947083473205566, 'learning_rate': 3.9019607843137254e-05, 'epoch': 2.2}
{'loss': 0.3418, 'grad_norm': 8.204344749450684, 'learning_rate': 3.882352941176471e-05, 'epoch': 2.24}
{'loss': 0.2481, 'grad_norm': 8.038412094116211, 'learning_rate': 3.862745098039216e-05, 'epoch': 2.27}
{'loss': 0.2757, 'grad_norm': 1.4780242443084717, 'learning_rate': 3.8431372549019614e-05

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.2000889629125595, 'eval_accuracy': 0.9227467811158798, 'eval_runtime': 146.7425, 'eval_samples_per_second': 7.939, 'eval_steps_per_second': 0.497, 'epoch': 3.0}
{'loss': 0.152, 'grad_norm': 2.9769539833068848, 'learning_rate': 3.4901960784313725e-05, 'epoch': 3.02}
{'loss': 0.2879, 'grad_norm': 4.8267974853515625, 'learning_rate': 3.470588235294118e-05, 'epoch': 3.06}
{'loss': 0.1729, 'grad_norm': 8.955961227416992, 'learning_rate': 3.450980392156863e-05, 'epoch': 3.1}
{'loss': 0.3143, 'grad_norm': 9.25610637664795, 'learning_rate': 3.431372549019608e-05, 'epoch': 3.14}
{'loss': 0.225, 'grad_norm': 7.599547863006592, 'learning_rate': 3.411764705882353e-05, 'epoch': 3.18}
{'loss': 0.2226, 'grad_norm': 7.530545711517334, 'learning_rate': 3.392156862745098e-05, 'epoch': 3.22}
{'loss': 0.163, 'grad_norm': 4.091729640960693, 'learning_rate': 3.372549019607844e-05, 'epoch': 3.25}
{'loss': 0.2517, 'grad_norm': 0.8671070337295532, 'learning_rate': 3.352941176470588e-05, 'epoch'

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.2321254462003708, 'eval_accuracy': 0.9201716738197425, 'eval_runtime': 247.3281, 'eval_samples_per_second': 4.71, 'eval_steps_per_second': 0.295, 'epoch': 4.0}
{'loss': 0.1285, 'grad_norm': 4.822474479675293, 'learning_rate': 2.9803921568627453e-05, 'epoch': 4.04}
{'loss': 0.3053, 'grad_norm': 10.9007568359375, 'learning_rate': 2.9607843137254905e-05, 'epoch': 4.08}
{'loss': 0.1908, 'grad_norm': 0.3530847728252411, 'learning_rate': 2.9411764705882354e-05, 'epoch': 4.12}
{'loss': 0.1414, 'grad_norm': 6.022669792175293, 'learning_rate': 2.9215686274509806e-05, 'epoch': 4.16}
{'loss': 0.1235, 'grad_norm': 0.2717658579349518, 'learning_rate': 2.9019607843137258e-05, 'epoch': 4.2}
{'loss': 0.2748, 'grad_norm': 2.3142852783203125, 'learning_rate': 2.8823529411764703e-05, 'epoch': 4.24}
{'loss': 0.208, 'grad_norm': 3.1117758750915527, 'learning_rate': 2.8627450980392155e-05, 'epoch': 4.27}
{'loss': 0.2414, 'grad_norm': 4.461186408996582, 'learning_rate': 2.8431372549019608e-05

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.24086377024650574, 'eval_accuracy': 0.9201716738197425, 'eval_runtime': 149.0203, 'eval_samples_per_second': 7.818, 'eval_steps_per_second': 0.49, 'epoch': 5.0}
{'loss': 0.1566, 'grad_norm': 12.72549057006836, 'learning_rate': 2.4901960784313726e-05, 'epoch': 5.02}
{'loss': 0.2349, 'grad_norm': 6.034121036529541, 'learning_rate': 2.4705882352941178e-05, 'epoch': 5.06}
{'loss': 0.1335, 'grad_norm': 13.013909339904785, 'learning_rate': 2.4509803921568626e-05, 'epoch': 5.1}
{'loss': 0.1498, 'grad_norm': 4.224095821380615, 'learning_rate': 2.431372549019608e-05, 'epoch': 5.14}
{'loss': 0.2132, 'grad_norm': 12.922297477722168, 'learning_rate': 2.411764705882353e-05, 'epoch': 5.18}
{'loss': 0.2218, 'grad_norm': 13.278406143188477, 'learning_rate': 2.3921568627450983e-05, 'epoch': 5.22}
{'loss': 0.2502, 'grad_norm': 6.417782783508301, 'learning_rate': 2.372549019607843e-05, 'epoch': 5.25}
{'loss': 0.1504, 'grad_norm': 4.598542213439941, 'learning_rate': 2.3529411764705884e-05,

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.17081515491008759, 'eval_accuracy': 0.9416309012875537, 'eval_runtime': 150.0562, 'eval_samples_per_second': 7.764, 'eval_steps_per_second': 0.486, 'epoch': 6.0}
{'loss': 0.0928, 'grad_norm': 6.786096096038818, 'learning_rate': 1.980392156862745e-05, 'epoch': 6.04}
{'loss': 0.1554, 'grad_norm': 11.25170612335205, 'learning_rate': 1.9607843137254903e-05, 'epoch': 6.08}
{'loss': 0.0553, 'grad_norm': 7.1980390548706055, 'learning_rate': 1.9411764705882355e-05, 'epoch': 6.12}
{'loss': 0.0135, 'grad_norm': 5.88771915435791, 'learning_rate': 1.9215686274509807e-05, 'epoch': 6.16}
{'loss': 0.2331, 'grad_norm': 0.7407482862472534, 'learning_rate': 1.9019607843137255e-05, 'epoch': 6.2}
{'loss': 0.2102, 'grad_norm': 2.0650699138641357, 'learning_rate': 1.8823529411764708e-05, 'epoch': 6.24}
{'loss': 0.1855, 'grad_norm': 4.044532775878906, 'learning_rate': 1.862745098039216e-05, 'epoch': 6.27}
{'loss': 0.1367, 'grad_norm': 2.246584177017212, 'learning_rate': 1.843137254901961e-05,

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.17084607481956482, 'eval_accuracy': 0.9416309012875537, 'eval_runtime': 150.6362, 'eval_samples_per_second': 7.734, 'eval_steps_per_second': 0.485, 'epoch': 7.0}
{'loss': 0.2365, 'grad_norm': 9.027932167053223, 'learning_rate': 1.4901960784313726e-05, 'epoch': 7.02}
{'loss': 0.1425, 'grad_norm': 0.7207277417182922, 'learning_rate': 1.4705882352941177e-05, 'epoch': 7.06}
{'loss': 0.103, 'grad_norm': 0.7626366019248962, 'learning_rate': 1.4509803921568629e-05, 'epoch': 7.1}
{'loss': 0.0641, 'grad_norm': 7.861717700958252, 'learning_rate': 1.4313725490196078e-05, 'epoch': 7.14}
{'loss': 0.1008, 'grad_norm': 10.903027534484863, 'learning_rate': 1.411764705882353e-05, 'epoch': 7.18}
{'loss': 0.1598, 'grad_norm': 31.84517478942871, 'learning_rate': 1.3921568627450982e-05, 'epoch': 7.22}
{'loss': 0.109, 'grad_norm': 3.7606818675994873, 'learning_rate': 1.3725490196078432e-05, 'epoch': 7.25}
{'loss': 0.1736, 'grad_norm': 0.11190846562385559, 'learning_rate': 1.3529411764705883e

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.19189943373203278, 'eval_accuracy': 0.934763948497854, 'eval_runtime': 182.6368, 'eval_samples_per_second': 6.379, 'eval_steps_per_second': 0.4, 'epoch': 8.0}
{'loss': 0.216, 'grad_norm': 6.726747035980225, 'learning_rate': 9.803921568627451e-06, 'epoch': 8.04}
{'loss': 0.0964, 'grad_norm': 0.8646385669708252, 'learning_rate': 9.607843137254903e-06, 'epoch': 8.08}
{'loss': 0.0608, 'grad_norm': 0.6351845860481262, 'learning_rate': 9.411764705882354e-06, 'epoch': 8.12}
{'loss': 0.1332, 'grad_norm': 6.478155612945557, 'learning_rate': 9.215686274509804e-06, 'epoch': 8.16}
{'loss': 0.0401, 'grad_norm': 0.06587941199541092, 'learning_rate': 9.019607843137255e-06, 'epoch': 8.2}
{'loss': 0.172, 'grad_norm': 3.740788459777832, 'learning_rate': 8.823529411764707e-06, 'epoch': 8.24}
{'loss': 0.1814, 'grad_norm': 0.48280563950538635, 'learning_rate': 8.627450980392157e-06, 'epoch': 8.27}
{'loss': 0.1314, 'grad_norm': 0.843451738357544, 'learning_rate': 8.43137254901961e-06, 'epoch

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.20923149585723877, 'eval_accuracy': 0.9407725321888412, 'eval_runtime': 150.5163, 'eval_samples_per_second': 7.74, 'eval_steps_per_second': 0.485, 'epoch': 9.0}
{'loss': 0.1528, 'grad_norm': 0.14838454127311707, 'learning_rate': 4.901960784313726e-06, 'epoch': 9.02}
{'loss': 0.0413, 'grad_norm': 15.815964698791504, 'learning_rate': 4.705882352941177e-06, 'epoch': 9.06}
{'loss': 0.1437, 'grad_norm': 7.116632461547852, 'learning_rate': 4.509803921568627e-06, 'epoch': 9.1}
{'loss': 0.1142, 'grad_norm': 28.419443130493164, 'learning_rate': 4.313725490196079e-06, 'epoch': 9.14}
{'loss': 0.0192, 'grad_norm': 0.368515282869339, 'learning_rate': 4.11764705882353e-06, 'epoch': 9.18}
{'loss': 0.0649, 'grad_norm': 0.19300584495067596, 'learning_rate': 3.92156862745098e-06, 'epoch': 9.22}
{'loss': 0.1388, 'grad_norm': 2.903005838394165, 'learning_rate': 3.7254901960784316e-06, 'epoch': 9.25}
{'loss': 0.096, 'grad_norm': 14.219625473022461, 'learning_rate': 3.5294117647058825e-06, '

  0%|          | 0/73 [00:00<?, ?it/s]

{'eval_loss': 0.20011577010154724, 'eval_accuracy': 0.944206008583691, 'eval_runtime': 208.2464, 'eval_samples_per_second': 5.594, 'eval_steps_per_second': 0.351, 'epoch': 10.0}
{'train_runtime': 14666.295, 'train_samples_per_second': 2.78, 'train_steps_per_second': 0.174, 'train_loss': 0.24141756545094883, 'epoch': 10.0}


  0%|          | 0/37 [00:00<?, ?it/s]

Test Results: {'eval_loss': 0.23524129390716553, 'eval_accuracy': 0.9415807560137457, 'eval_runtime': 90.1285, 'eval_samples_per_second': 6.457, 'eval_steps_per_second': 0.411, 'epoch': 10.0}


In [None]:
# Persist the trained model, then reload it to check that the saved files are usable.
# NOTE(review): with load_best_model_at_end=True the Trainer should have restored the
# best validation checkpoint into `model` before this cell — confirm if settings change.
from transformers import ViTForImageClassification

save_dir = "./vit_model"
model.save_pretrained(save_dir)

loaded_model = ViTForImageClassification.from_pretrained(save_dir)
print(loaded_model)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe