In [2]:
import os
import glob
import shutil
import copy

import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2

from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy
from torchvision import models, transforms

from sklearn.model_selection import KFold, train_test_split

In [3]:
labels_key = {
    "KL01": 0,
    "KL234": 1
}

In [4]:
class ClassificationDataset(Dataset):
    def __init__(self, images, label_key, transform=None):
        self.images = images
        self.labels = label_key
        self.transforms = transform
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img = self.images[idx]
        label = self.labels[img.split("/")[-2]]
        img = np.array(Image.open(img).convert("RGB"))
        if self.transforms is not None:
            img = self.transforms(image=img)["image"]
            
        return img.float(), torch.tensor(label).float()

In [5]:
def get_model():
    import copy
    model = copy.deepcopy(models.resnet18(weights=models.ResNet18_Weights.DEFAULT))
    # for param in model.parameters():
    #     param.requires_grad_ = False
    model.fc = nn.Linear(model.fc.in_features, 1)
    model = model.to(device)

    return model

In [8]:
KL01_real = np.array(list(glob.iglob("/data_vault/hexai/KL01_KL234_Real/KL01/**")))
KL234_real = np.array(list(glob.iglob("/data_vault/hexai/KL01_KL234_Real/KL234/**")))
all_real =  np.concatenate([KL01_real, KL234_real])

KL01_fake = np.array(list(glob.iglob("/data_vault/hexai/SyntheticKneeImages/KL01/**")))
KL234_fake = np.array(list(glob.iglob("/data_vault/hexai/SyntheticKneeImages/KL234/**")))
all_fake =  np.concatenate([np.random.choice(KL01_fake, len(KL01_real)), np.random.choice(KL234_fake, len(KL234_real))])

In [20]:
train_fake, test_fake = train_test_split(all_fake, random_state=42)
train_fake, valid_fake = train_test_split(train_fake, random_state=42)

augmentations = A.Compose([A.Resize(224, 224), ToTensorV2()])

train_dataset = ClassificationDataset(train_fake, labels_key, transform=augmentations)
valid_dataset = ClassificationDataset(valid_fake, labels_key, transform=augmentations)
test_dataset = ClassificationDataset(test_fake, labels_key, transform=augmentations)

train_dataloader = DataLoader(train_dataset, batch_size=64)
valid_dataloader = DataLoader(valid_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_model()
optimizer = optim.Adam(model.parameters(), lr=5e-04)

In [21]:
from tqdm import tqdm
best_model = None
best_acc = 0.
fake_accuracies = []

criterion = nn.BCELoss()
acc_metric = Accuracy(task="binary").to(device)

for epoch in range(2):
    running_loss = []
    running_acc = []
    print(f"Epoch {epoch}")
    model.train()
    for img, label in tqdm(train_dataloader):
        optimizer.zero_grad()
        img = img.to(device)
        label = label.to(device).float()
        out = torch.sigmoid(model(img)).squeeze(dim=-1)
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        running_loss.append(loss.item())
        running_acc.append(acc_metric(out, label).item())
    
    print(f"Train Loss: {np.mean(running_loss)}")
    print(f"Train Accuracy: {np.mean(running_acc)}")
    
    model.eval()
    running_loss = []
    running_acc = []
    for img, label in valid_dataloader:        
        img = img.to(device)
        label = label.to(device).float()
        out = torch.sigmoid(model(img.to(device))).squeeze(dim=-1)
        running_acc.append(acc_metric(out, label).item())
    val_acc =  np.mean(running_acc)
    print(f"Val. Accuracy: {val_acc}")

    if val_acc > best_acc:
        best_acc = val_acc
        best_model = model

running_acc = []
for img, label in test_dataloader:        
    img = img.to(device)
    label = label.to(device).float()
    out = model(img.to(device)).squeeze(dim=-1)
    running_acc.append(acc_metric(out, label).item())
print(f"Test Accuracy: {np.mean(running_acc)}")


Epoch 0


100%|██████████| 66/66 [00:42<00:00,  1.54it/s]


Train Loss: 0.040508724184870844
Train Accuracy: 0.978219696969697
Val. Accuracy: 1.0
Epoch 1


100%|██████████| 66/66 [00:36<00:00,  1.82it/s]


Train Loss: 0.000187539793070314
Train Accuracy: 1.0
Val. Accuracy: 1.0
Test Accuracy: 1.0


In [22]:
KL01_real = np.array(list(glob.iglob("/data_vault/hexai/KL01_KL234_Real/KL01/**")))
KL234_real = np.array(list(glob.iglob("/data_vault/hexai/KL01_KL234_Real/KL234/**")))
all_real =  np.concatenate([KL01_real, KL234_real])

test_dataset = ClassificationDataset(all_real, labels_key, transform=augmentations)
test_dataloader = DataLoader(test_dataset, batch_size=64)


In [23]:
train_real, test_real= train_test_split(all_real, random_state=42, test_size=0.9)
train_real, valid_real= train_test_split(train_real, random_state=42, test_size=0.2)

train_dataset_real = ClassificationDataset(train_real, labels_key, transform=augmentations)
valid_dataset_real = ClassificationDataset(valid_real, labels_key, transform=augmentations)
test_dataset_real = ClassificationDataset(test_real, labels_key, transform=augmentations)

train_dataloader_real = DataLoader(train_dataset_real, batch_size=64)
valid_dataloader_real = DataLoader(valid_dataset_real, batch_size=64)
test_dataloader_real = DataLoader(test_dataset_real, batch_size=64)

In [28]:
best_model_v2 = None
best_acc = 0.
model = copy.deepcopy(best_model)
for param in model.parameters():
    param.requires_grad_ = False
model.fc.requires_grad_ = True
model=model.cuda()

optimizer = optim.Adam(model.parameters(), lr=5e-4)

for epoch in range(20):
    running_loss = []
    running_acc = []
    print(f"Epoch {epoch}")
    model.train()
    for img, label in tqdm(train_dataloader_real):
        optimizer.zero_grad()
        img = img.to(device)
        label = label.to(device).float()
        out = torch.sigmoid(model(img)).squeeze(dim=-1)
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        running_loss.append(loss.item())
        running_acc.append(acc_metric(out, label).item())
    
    print(f"Train Loss: {np.mean(running_loss)}")
    print(f"Train Accuracy: {np.mean(running_acc)}")
    
    model.eval()
    running_loss = []
    running_acc = []
    for img, label in valid_dataloader_real:        
        img = img.to(device)
        label = label.to(device).float()
        out = torch.sigmoid(model(img.to(device))).squeeze(dim=-1)
        running_acc.append(acc_metric(out, label).item())
    val_acc =  np.mean(running_acc)
    print(f"Val. Accuracy: {val_acc}")

    if val_acc > best_acc:
        best_acc = val_acc
        best_model_v2 = model

running_acc = []
for img, label in test_dataloader_real:        
    img = img.to(device)
    label = label.to(device).float()
    out = best_model_v2(img.to(device)).squeeze(dim=-1)
    running_acc.append(acc_metric(out, label).item())
print(f"Test Accuracy: {np.mean(running_acc)}")


Epoch 0


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:01<00:00,  6.98it/s]


Train Loss: 1.464504098892212
Train Accuracy: 0.621875
Val. Accuracy: 0.6747159163157145
Epoch 1


100%|██████████| 10/10 [00:01<00:00,  7.30it/s]


Train Loss: 0.6263467729091644
Train Accuracy: 0.7603124976158142
Val. Accuracy: 0.7414772709210714
Epoch 2


100%|██████████| 10/10 [00:01<00:00,  7.01it/s]


Train Loss: 0.4163625329732895
Train Accuracy: 0.8171875
Val. Accuracy: 0.7930871248245239
Epoch 3


100%|██████████| 10/10 [00:01<00:00,  7.33it/s]


Train Loss: 0.3081431984901428
Train Accuracy: 0.86875
Val. Accuracy: 0.7215909163157145
Epoch 4


100%|██████████| 10/10 [00:01<00:00,  7.28it/s]


Train Loss: 0.2362002246081829
Train Accuracy: 0.909375
Val. Accuracy: 0.767518937587738
Epoch 5


100%|██████████| 10/10 [00:01<00:00,  7.30it/s]


Train Loss: 0.17511709704995154
Train Accuracy: 0.95
Val. Accuracy: 0.7878787914911906
Epoch 6


100%|██████████| 10/10 [00:01<00:00,  7.31it/s]


Train Loss: 0.12164773978292942
Train Accuracy: 0.975
Val. Accuracy: 0.8035037914911906
Epoch 7


100%|██████████| 10/10 [00:01<00:00,  7.31it/s]


Train Loss: 0.07959169521927834
Train Accuracy: 0.996875
Val. Accuracy: 0.8035037914911906
Epoch 8


100%|██████████| 10/10 [00:01<00:00,  7.32it/s]


Train Loss: 0.049932716973125936
Train Accuracy: 0.9984375
Val. Accuracy: 0.8035037914911906
Epoch 9


100%|██████████| 10/10 [00:01<00:00,  7.32it/s]


Train Loss: 0.03172651063650846
Train Accuracy: 1.0
Val. Accuracy: 0.8087121248245239
Epoch 10


100%|██████████| 10/10 [00:01<00:00,  7.33it/s]


Train Loss: 0.021310441475361586
Train Accuracy: 1.0
Val. Accuracy: 0.8035037914911906
Epoch 11


100%|██████████| 10/10 [00:01<00:00,  7.31it/s]


Train Loss: 0.015263016242533923
Train Accuracy: 1.0
Val. Accuracy: 0.8087121248245239
Epoch 12


100%|██████████| 10/10 [00:01<00:00,  7.31it/s]


Train Loss: 0.011573695112019777
Train Accuracy: 1.0
Val. Accuracy: 0.8035037914911906
Epoch 13


100%|██████████| 10/10 [00:01<00:00,  7.32it/s]


Train Loss: 0.009183799708262086
Train Accuracy: 1.0
Val. Accuracy: 0.8035037914911906
Epoch 14


100%|██████████| 10/10 [00:01<00:00,  7.00it/s]


Train Loss: 0.007547787390649319
Train Accuracy: 1.0
Val. Accuracy: 0.8035037914911906
Epoch 15


100%|██████████| 10/10 [00:01<00:00,  7.33it/s]


Train Loss: 0.006367328623309731
Train Accuracy: 1.0
Val. Accuracy: 0.8035037914911906
Epoch 16


100%|██████████| 10/10 [00:01<00:00,  6.86it/s]


Train Loss: 0.005482796858996153
Train Accuracy: 1.0
Val. Accuracy: 0.8035037914911906
Epoch 17


100%|██████████| 10/10 [00:01<00:00,  7.23it/s]


Train Loss: 0.004797426471486688
Train Accuracy: 1.0
Val. Accuracy: 0.7883522709210714
Epoch 18


100%|██████████| 10/10 [00:01<00:00,  7.29it/s]


Train Loss: 0.004251281195320189
Train Accuracy: 1.0
Val. Accuracy: 0.7883522709210714
Epoch 19


100%|██████████| 10/10 [00:01<00:00,  7.30it/s]


Train Loss: 0.0038069609319791196
Train Accuracy: 1.0
Val. Accuracy: 0.7883522709210714
Test Accuracy: 0.7589134380930946


In [14]:
torch.save(best_model_v2, "resnet18_finetuned_KL.pt")

In [15]:
train_real, test_real= train_test_split(all_real, random_state=42)
train_real, valid_real= train_test_split(train_real, random_state=42)

train_dataset_real = ClassificationDataset(train_real, labels_key, transform=augmentations)
valid_dataset_real = ClassificationDataset(valid_real, labels_key, transform=augmentations)
test_dataset_real = ClassificationDataset(test_real, labels_key, transform=augmentations)

train_dataloader = DataLoader(train_dataset_real, batch_size=64)
valid_dataloader = DataLoader(valid_dataset_real, batch_size=64)
test_dataloader = DataLoader(test_dataset_real, batch_size=64)

In [17]:
from tqdm import tqdm
best_model = None
best_acc = 0.
fake_accuracies = []
model = get_model()

criterion = nn.BCELoss()
acc_metric = Accuracy(task="binary").to(device)
optimizer = optim.Adam(model.parameters(), lr=5e-4)

for epoch in range(20):
    running_loss = []
    running_acc = []
    print(f"Epoch {epoch}")
    model.train()
    for img, label in tqdm(train_dataloader):
        optimizer.zero_grad()
        img = img.to(device)
        label = label.to(device).float()
        out = torch.sigmoid(model(img)).squeeze(dim=-1)
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        running_loss.append(loss.item())
        running_acc.append(acc_metric(out, label).item())
    
    print(f"Train Loss: {np.mean(running_loss)}")
    print(f"Train Accuracy: {np.mean(running_acc)}")
    
    model.eval()
    running_loss = []
    running_acc = []
    for img, label in valid_dataloader:        
        img = img.to(device)
        label = label.to(device).float()
        out = torch.sigmoid(model(img.to(device))).squeeze(dim=-1)
        running_acc.append(acc_metric(out, label).item())
    val_acc =  np.mean(running_acc)
    print(f"Val. Accuracy: {val_acc}")

    if val_acc > best_acc:
        best_acc = val_acc
        best_model = model




Epoch 0


100%|██████████| 66/66 [00:10<00:00,  6.56it/s]


Train Loss: 0.49566961644273816
Train Accuracy: 0.7607323229312897
Val. Accuracy: 0.8195118795741688
Epoch 1


100%|██████████| 66/66 [00:10<00:00,  6.51it/s]


Train Loss: 0.3249477967619896
Train Accuracy: 0.8623737376747709
Val. Accuracy: 0.7998708676208149
Epoch 2


100%|██████████| 66/66 [00:10<00:00,  6.51it/s]


Train Loss: 0.24797720247597405
Train Accuracy: 0.8958859425602537
Val. Accuracy: 0.8077866733074188
Epoch 3


100%|██████████| 66/66 [00:10<00:00,  6.16it/s]


Train Loss: 0.1898858528019804
Train Accuracy: 0.9267150669386892
Val. Accuracy: 0.8063791312954642
Epoch 4


100%|██████████| 66/66 [00:10<00:00,  6.51it/s]


Train Loss: 0.1821269103410569
Train Accuracy: 0.9285037878787878
Val. Accuracy: 0.7490831613540649
Epoch 5


100%|██████████| 66/66 [00:10<00:00,  6.59it/s]


Train Loss: 0.11033567616885359
Train Accuracy: 0.9576231060606061
Val. Accuracy: 0.7990315095944838
Epoch 6


100%|██████████| 66/66 [00:10<00:00,  6.55it/s]


Train Loss: 0.12271814336153594
Train Accuracy: 0.9505208333333334
Val. Accuracy: 0.7893336767500098
Epoch 7


100%|██████████| 66/66 [00:10<00:00,  6.49it/s]


Train Loss: 0.09731369228525595
Train Accuracy: 0.9611742424242424
Val. Accuracy: 0.8165418397296559
Epoch 8


100%|██████████| 66/66 [00:10<00:00,  6.51it/s]


Train Loss: 0.051676602004039465
Train Accuracy: 0.9820075757575758
Val. Accuracy: 0.8163223131136461
Epoch 9


100%|██████████| 66/66 [00:10<00:00,  6.57it/s]


Train Loss: 0.02219417081285042
Train Accuracy: 0.9931870789238901
Val. Accuracy: 0.843775827776302
Epoch 10


100%|██████████| 66/66 [00:10<00:00,  6.58it/s]


Train Loss: 0.01772012641110147
Train Accuracy: 0.9945549242424242
Val. Accuracy: 0.8306430794975974
Epoch 11


100%|██████████| 66/66 [00:10<00:00,  6.51it/s]


Train Loss: 0.021432209594146996
Train Accuracy: 0.9933712121212122
Val. Accuracy: 0.8466167368672111
Epoch 12


100%|██████████| 66/66 [00:10<00:00,  6.57it/s]


Train Loss: 0.04539207039245715
Train Accuracy: 0.9813499577117689
Val. Accuracy: 0.7908703522248701
Epoch 13


100%|██████████| 66/66 [00:10<00:00,  6.56it/s]


Train Loss: 0.02996207151392644
Train Accuracy: 0.9919507575757576
Val. Accuracy: 0.8138300613923506
Epoch 14


100%|██████████| 66/66 [00:10<00:00,  6.28it/s]


Train Loss: 0.022443739453923296
Train Accuracy: 0.9921875
Val. Accuracy: 0.8385717977177013
Epoch 15


100%|██████████| 66/66 [00:10<00:00,  6.52it/s]


Train Loss: 0.013205095165176317
Train Accuracy: 0.9957386363636364
Val. Accuracy: 0.8481534096327695
Epoch 16


100%|██████████| 66/66 [00:10<00:00,  6.57it/s]


Train Loss: 0.008214052041170582
Train Accuracy: 0.9967382152875265
Val. Accuracy: 0.83986311880025
Epoch 17


100%|██████████| 66/66 [00:10<00:00,  6.59it/s]


Train Loss: 0.02112949200267339
Train Accuracy: 0.9926609848484849
Val. Accuracy: 0.82423811880025
Epoch 18


100%|██████████| 66/66 [00:10<00:00,  6.46it/s]


Train Loss: 0.030942943469254358
Train Accuracy: 0.9888731060606061
Val. Accuracy: 0.8409349186853929
Epoch 19


100%|██████████| 66/66 [00:10<00:00,  6.60it/s]


Train Loss: 0.025718428496144374
Train Accuracy: 0.9907670454545454
Val. Accuracy: 0.8357308886267922
Test Accuracy: 0.8352083325386047


In [18]:
running_acc = []
for img, label in test_dataloader:        
    img = img.to(device)
    label = label.to(device).float()
    out = best_model(img.to(device)).squeeze(dim=-1)
    running_acc.append(acc_metric(out, label).item())
print(f"Test Accuracy: {np.mean(running_acc)}")

Test Accuracy: 0.8352083325386047


In [19]:
torch.save(best_model, "resnet18_KL_real.pt")