In [28]:
!pip install --upgrade pip
!pip install datasets torch transformers torchvision torchaudio requests "urllib3<2" scikit-learn matplotlib



In [29]:

import torch
from datasets import load_dataset, load_metric
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, TrainingArguments, Trainer, ResNetConfig, ResNetForImageClassification
import numpy as np


In [30]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
x = torch.ones(1).to(device)
print(x)

tensor([1.], device='mps:0')


In [31]:
def load_data():
    d_types = ['train','test','validation']
    datasets = []
    for d in d_types:
        datasets.append(load_dataset(
            'trpakov/chest-xray-classification',
            'full',
            split=d,
        ))
    return datasets

dataset_train, dataset_test, dataset_valid = load_data()
dataset_train


Found cached dataset chest-xray-classification (/Users/tommyluu/.cache/huggingface/datasets/trpakov___chest-xray-classification/full/1.0.0/395971eacf4f6d99e8b95f965fa0a8af5a31a5118529992bba3f60c4a7af1b92)
Found cached dataset chest-xray-classification (/Users/tommyluu/.cache/huggingface/datasets/trpakov___chest-xray-classification/full/1.0.0/395971eacf4f6d99e8b95f965fa0a8af5a31a5118529992bba3f60c4a7af1b92)
Found cached dataset chest-xray-classification (/Users/tommyluu/.cache/huggingface/datasets/trpakov___chest-xray-classification/full/1.0.0/395971eacf4f6d99e8b95f965fa0a8af5a31a5118529992bba3f60c4a7af1b92)


Dataset({
    features: ['image_file_path', 'image', 'labels'],
    num_rows: 12230
})

In [32]:
# check how many labels/number of classes
num_classes = len(set(dataset_train['labels']))
labels = dataset_train.features['labels'].names
num_classes, labels

(2, ['PNEUMONIA', 'NORMAL'])

In [33]:
# import feature extraction model
model_id = 'microsoft/resnet-50'

feature_extractor = AutoFeatureExtractor.from_pretrained(
    model_id
)
config = ResNetConfig()
config.num_labels = num_classes
model = ResNetForImageClassification(config)
model = model.from_pretrained(
    model_id,  # classification head
).train().to(device)

In [34]:
def preprocess(batch):
    # take a list of PIL images and turn them to pixel values
    inputs = feature_extractor(
        batch['image'],
        return_tensors='pt'
    )
    # include the labels
    inputs['labels'] = batch['labels']
    return inputs

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

# transform the training dataset
prepared_train = dataset_train.with_transform(preprocess)
# ... and the testing dataset
prepared_test = dataset_test.with_transform(preprocess)
# ... and the validation dataset
prepared_valid = dataset_valid.with_transform(preprocess)

In [35]:
import numpy as np
from datasets import load_metric

# accuracy metric
metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(
        predictions=np.argmax(p.predictions, axis=1),
        references=p.label_ids
    )


In [36]:
from transformers import Trainer

output_dir = './resnet_model'

# training the model
training_args = TrainingArguments(
  output_dir,
  per_device_train_batch_size=32,
  evaluation_strategy="steps",
  num_train_epochs=2,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=0.001,
  save_total_limit=2,
  optim='adamw_torch',
  remove_unused_columns=False,
  push_to_hub=False,
  load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_train,
    eval_dataset=prepared_test,
    tokenizer=feature_extractor,
)

#change resume_from_checkpoint to True if you want to resume training
train_results = trainer.train(resume_from_checkpoint=False)
# save tokenizer with the model
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
# save the trainer state
trainer.save_state()
# evaluate with validation
metrics = trainer.evaluate(prepared_valid)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)


  0%|          | 0/766 [02:42<?, ?it/s]
  1%|▏         | 10/766 [01:07<1:21:48,  6.49s/it]

{'loss': 4.715, 'learning_rate': 0.0009869451697127939, 'epoch': 0.03}


  3%|▎         | 20/766 [02:12<1:20:14,  6.45s/it]

{'loss': 0.5802, 'learning_rate': 0.0009738903394255875, 'epoch': 0.05}


  4%|▍         | 30/766 [03:16<1:18:20,  6.39s/it]

{'loss': 0.2781, 'learning_rate': 0.0009608355091383812, 'epoch': 0.08}


  5%|▌         | 40/766 [04:19<1:16:51,  6.35s/it]

{'loss': 0.2414, 'learning_rate': 0.0009477806788511749, 'epoch': 0.1}


  7%|▋         | 50/766 [05:24<1:18:05,  6.54s/it]

{'loss': 0.1592, 'learning_rate': 0.0009347258485639687, 'epoch': 0.13}


  8%|▊         | 60/766 [06:28<1:14:35,  6.34s/it]

{'loss': 0.1929, 'learning_rate': 0.0009216710182767625, 'epoch': 0.16}


  9%|▉         | 70/766 [07:32<1:14:19,  6.41s/it]

{'loss': 0.1253, 'learning_rate': 0.0009086161879895561, 'epoch': 0.18}


 10%|█         | 80/766 [08:36<1:12:31,  6.34s/it]

{'loss': 0.219, 'learning_rate': 0.00089556135770235, 'epoch': 0.21}


 12%|█▏        | 90/766 [09:39<1:11:42,  6.37s/it]

{'loss': 0.1913, 'learning_rate': 0.0008825065274151436, 'epoch': 0.23}


 13%|█▎        | 100/766 [10:45<1:11:52,  6.48s/it]

{'loss': 0.1239, 'learning_rate': 0.0008694516971279373, 'epoch': 0.26}


                                                   
 13%|█▎        | 100/766 [11:03<1:11:52,  6.48s/it]

{'eval_loss': 0.20632897317409515, 'eval_accuracy': 0.9243986254295533, 'eval_runtime': 17.3052, 'eval_samples_per_second': 33.631, 'eval_steps_per_second': 4.218, 'epoch': 0.26}


 14%|█▍        | 110/766 [12:10<1:16:15,  6.98s/it]

{'loss': 0.1571, 'learning_rate': 0.0008563968668407311, 'epoch': 0.29}


 16%|█▌        | 120/766 [13:16<1:10:12,  6.52s/it]

{'loss': 0.2204, 'learning_rate': 0.0008433420365535248, 'epoch': 0.31}


 17%|█▋        | 130/766 [14:24<1:11:36,  6.76s/it]

{'loss': 0.1548, 'learning_rate': 0.0008302872062663186, 'epoch': 0.34}


 17%|█▋        | 133/766 [14:44<1:10:27,  6.68s/it]

In [None]:
from matplotlib import pyplot as plt

x = [log['step'] for log in trainer.state.log_history if 'loss' in log]
y = [log['loss'] for log in trainer.state.log_history if 'loss' in log]


# y axis is loss
# x axis is step
plt.plot(x, y)
plt.xlabel('step')
plt.ylabel('loss')
plt.title('ResNet50: Loss over train steps')
