In [1]:
import os
import time
import random
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import torch
from torchvision import transforms, models
import PIL
from PIL import Image

from transformers import ViTImageProcessor, ViTForImageClassification, Trainer, TrainingArguments
import datasets
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from tqdm import tqdm

2023-04-17 01:14:22.382442: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

root_path = '/scratch/sss9772/early-detection-of-3d-printing-issues/'


In [3]:
model_name = "google/vit-base-patch16-224"

feature_extractor = ViTImageProcessor.from_pretrained(model_name)

img_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(60),
        transforms.ColorJitter(brightness=(0.5,1.5),contrast=(1),saturation=(0.5,1.5),hue=(-0.1,0.1)),
        transforms.Resize((400,400), interpolation=transforms.InterpolationMode.NEAREST)
])

def vit_transform(example_batch):
    example_batch['image'] = [img_transform(img) for img in example_batch['image']]
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
    inputs['label'] = example_batch['label']
    return inputs

dataset = datasets.load_dataset(os.path.join(root_path, 'my_dataset.py'), num_proc=8, split='train', cache_dir="/scratch/sss9772/.cache")
dataset = dataset.cast_column("image", datasets.Image())
dataset = dataset.shuffle(seed=SEED)
dataset.set_format('torch')
dataset.set_transform(vit_transform)

train_test = dataset.train_test_split(0.1, stratify_by_column='label')

Found cached dataset my_dataset (/scratch/sss9772/.cache/my_dataset/default/0.0.0/71801c53399c6f25b1f58897c3a6286da73c00f60a6b6b86daa89653e3b29e8e)
Loading cached shuffled indices for dataset at /scratch/sss9772/.cache/my_dataset/default/0.0.0/71801c53399c6f25b1f58897c3a6286da73c00f60a6b6b86daa89653e3b29e8e/cache-c8317d096051a40a.arrow


In [4]:
model = ViTForImageClassification.from_pretrained(model_name, num_labels=2, ignore_mismatched_sizes=True)
model.classifier = torch.nn.Linear(in_features=768, out_features=2, bias=True)
model.classifier.requires_grad = True

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
BATCH_SIZE = 128
number = 5
run_name = 'run_' + str(number)
training_args = TrainingArguments(
    output_dir = './output_' + str(number),
    num_train_epochs=4,
    overwrite_output_dir=True,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size= 64,
    gradient_accumulation_steps = 4,    
    evaluation_strategy = 'steps',
    save_strategy='steps',
    eval_steps=50,
    save_steps=50,
    disable_tqdm = False, 
    warmup_steps=50,
    logging_steps = 50,
    logging_dir='./logs_' + str(number),
    remove_unused_columns=False,
    dataloader_num_workers = 8,
    run_name = run_name,
    report_to='wandb',
    load_best_model_at_end=True,
    save_total_limit=2,
    fp16=True
)

In [11]:
def compute_metrics(predictor):
    labels = predictor.label_ids
    preds = predictor.predictions.argmax(-1)
    acc = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    return {
        'precision' : precision,
        'recall' : recall,
        'f1' : f1,
        'accuracy': acc
    }

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }


class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        inp = inputs.get('pixel_values')
        outputs = model(inp)
        logits = outputs.get("logits")
        loss = torch.nn.CrossEntropyLoss()(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss    

trainer = CustomTrainer(
            model=model,
            args=training_args,
            data_collator=collate_fn,
            compute_metrics=compute_metrics,
            train_dataset=train_test['train'],
            eval_dataset=train_test['test'],
            tokenizer=feature_extractor
           )

In [None]:
trainer.train()



Step,Training Loss,Validation Loss,Precision,Recall,F1
50,0.0639,0.069612,0.866738,0.681345,0.740359
100,0.0604,0.055867,0.865492,0.770092,0.810038
150,0.0534,0.053156,0.88644,0.790355,0.831027
200,0.0467,0.043909,0.870889,0.845732,0.85783


In [None]:
# torch.save(model.state_dict(), 'vit_run_1.pth')

In [None]:
trainer.state.best_model_checkpoint

In [None]:
finetuned_model = ViTForImageClassification.from_pretrained(trainer.state.best_model_checkpoint)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def vit_transform_test(example_batch):
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
    inputs['image_path'] = example_batch['image_path']
    return inputs


In [None]:
test_dataset = datasets.load_dataset(os.path.join(root_path, 'my_dataset_test.py'), num_proc=8, split='test', cache_dir="/scratch/sss9772/.cache")
test_dataset = test_dataset.cast_column("image", datasets.Image())
test_dataset.set_format('torch')
test_dataset.set_transform(vit_transform_test)

In [None]:
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16, num_workers = 8, shuffle=False)
finetuned_model.to(device)
finetuned_model.eval()
res = []
for i,data in enumerate(tqdm(test_dataloader)):
    pixel_values = data['pixel_values'].to(device)
    preds_probs = finetuned_model(pixel_values)
    preds_class = torch.argmax(preds_probs['logits'], dim=-1)
    preds_class = preds_class.cpu().detach().numpy()
    for i in range(len(data['image_path'])):
        res.append([data['image_path'][i], int(preds_class[i])])

sub_file = 'submission'+ str(time.time()) +'.csv'
df = pd.DataFrame(res, columns = ['img_path', 'has_under_extrusion'])
df.to_csv(sub_file, index=None)