In [84]:
import cv2
import numpy as np
fn = 'slices/subject1_top.tiff'
fn = 'patches/patch1.tiff'
def read_tiff(fn):
    im = cv2.imread(fn, -1)
    x = np.array(im)
    return x.astype(float)

def read_csv(fn):
    with open(fn,"r") as f:
        lines = [l.strip().split(",") for l in f.readlines()]
    return lines


In [221]:
import glob
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize)
class CTED_Dataset(Dataset):
    def __init__(self, split=None, transform=None):
        label_csv = read_csv("data/slice_labels.csv")
        severity_csv = read_csv("data/slice_severity.csv")

        slice_data = []
        label_data = []
        severity_data = []
        for (f1,label),(f2,severity) in zip(label_csv,severity_csv):
            im=read_tiff(f'data/slices/{f1}.tiff')
            slice_data.append(im)
            label_data.append(int(label)-1)
            severity_data.append(int(severity))
        slice_data = np.array(slice_data)
        mean, std = slice_data.mean(), slice_data.std()
        self.slice_data = torch.tensor((slice_data-mean)/std).unsqueeze(1).to(torch.float)
        self.label_data = torch.tensor(label_data)
        self.severity_data = torch.tensor(severity_data)
        if split is not None:
            self.slice_data = self.slice_data[split]
            self.label_data = self.label_data[split]
            self.severity_data = self.severity_data[split]
        self.transform = transform
        
    def __len__(self):
        return len(self.slice_data)
        
    def __getitem__(self, idx):
        
        img = self.slice_data[idx]
        if self.transform is not None:
            img = self.transform(img)
        return {"pixel_value": img,
                "label": self.label_data[idx]}

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_value"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    # severity = torch.tensor([example["severity"] for example in examples])
    
    return {"pixel_values": pixel_values, "labels": labels}

id2label = {0: "NT", 1: "CLE", 2: "PSE", 3: "PLE"}
label2id = {"NT":0, "CLE":1, "PSE":2, "PLE":3}
image_size = 256
_train_transforms = Compose([
            RandomResizedCrop(image_size),
            RandomHorizontalFlip()
        ])

_val_transforms = Compose([
            Resize(image_size),
            CenterCrop(image_size)
        ])

data_len = 115


In [238]:
from transformers import ViTForImageClassification, ViTConfig
config = ViTConfig(encoder_stride=16, 
                   patch_size=16,
                   image_size=image_size, 
                   hidden_size=384, 
                   num_attention_heads=6, 
                   num_channels=1, 
                   num_hidden_layers=4)
print(config)
config.id2label=id2label
config.label2id=label2id

model = ViTForImageClassification(config)

ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 384,
  "image_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 6,
  "num_channels": 1,
  "num_hidden_layers": 4,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.34.0"
}



In [237]:
train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=4)
# valid_loader = DataLoader(valid_dataset, collate_fn=collate_fn, batch_size=4)
x=next(iter(train_loader))
model(x["pixel_values"], labels=x['labels'])

ImageClassifierOutput(loss=tensor(1.7364, grad_fn=<NllLossBackward0>), logits=tensor([[-0.4632,  0.3593, -0.0728, -0.3229],
        [-0.4845,  0.3374, -0.0304, -0.3221],
        [-0.4378,  0.2554, -0.1312, -0.2963],
        [-0.4240,  0.2677, -0.2632, -0.3057]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [247]:
from transformers import TrainingArguments, Trainer
from sklearn.metrics import accuracy_score

metric_name = "accuracy"

args = TrainingArguments(
    f"copd-train",
    save_strategy="steps",
    evaluation_strategy="steps",
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    num_train_epochs=300,
    weight_decay=0.01,
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    remove_unused_columns=False,
)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return dict(accuracy=accuracy_score(predictions, labels))

In [251]:
args.save_pretrained(".")

AttributeError: 'TrainingArguments' object has no attribute 'save_pretrained'

In [248]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

In [249]:
trainer.train()

Step,Training Loss,Validation Loss,Accuracy
100,1.1876,1.006977,0.6
200,1.178,1.060478,0.6
300,1.1492,1.119821,0.6
400,1.1592,1.14695,0.6
500,1.0714,1.108275,0.65
600,1.0342,0.850624,0.7
700,0.9801,0.895124,0.65
800,0.9458,0.850323,0.7
900,0.9419,0.88491,0.75
1000,0.9274,1.067866,0.65


KeyboardInterrupt: 