In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os


# Collect the full path of every file under /kaggle/input whose name ends in
# 'png' (walk order is whatever os.walk yields).
image_paths = [
    os.path.join(root, name)
    for root, _subdirs, names in os.walk('/kaggle/input')
    for name in names
    if name.endswith('png')
]
    

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
# Read the two split list files; each line, stripped of surrounding
# whitespace, becomes one entry of the corresponding list.
with open('/kaggle/input/data/test_list.txt') as test_file:
    test_paths = [line.strip() for line in test_file.readlines()]

with open('/kaggle/input/data/train_val_list.txt') as train_file:
    train_paths = [line.strip() for line in train_file.readlines()]

# Metadata table, indexed by the 'Image Index' column so that rows can be
# looked up with entries.loc[<image index value>].
entries = pd.read_csv(
    '/kaggle/input/data/Data_Entry_2017.csv',
    index_col='Image Index',
)

In [3]:
from PIL import Image
from sklearn.preprocessing import LabelEncoder

import tqdm

In [4]:
# Encode each distinct 'Finding Labels' string (the full, possibly
# '|'-joined value) as one integer and store it in a new 'labels' column.
# NOTE(review): no cell visible in this notebook reads entries['labels'].
le = LabelEncoder()
finding_codes = le.fit_transform(entries['Finding Labels'])
entries['labels'] = finding_codes

In [5]:
# Split every distinct 'Finding Labels' value on '|' and gather the individual
# names; `unique` is the sorted, de-duplicated list of those names.
unique = sorted({
    name
    for combined in pd.unique(entries['Finding Labels'])
    for name in combined.split('|')
})

# Map each individual label name to its position in `unique`.
labels = {name: position for position, name in enumerate(unique)}

In [6]:
from torchvision import transforms
import torch
def _spread_first_channel(t):
    """Broadcast channel 0 of ``t`` over a new 3 x 280 x 280 tensor."""
    return torch.zeros(3, 280, 280) + t[0:1, :, :]


def _move_signal_to_last_channel(t):
    """Stack (channels 1-2 of ``t``) * -1 + (channels 1-2) on top of channel 0.

    The first two output channels are therefore ``x * -1 + x`` (zero for
    finite values) and the third output channel is input channel 0.
    """
    tail = t[[1, 2], :, :]
    return torch.vstack((tail * -1 + tail, t[0:1, :, :]))


# Preprocessing pipeline: resize to 280 x 280, convert to a tensor, then apply
# the two channel manipulations above in order.
trns = transforms.Compose([
    transforms.Resize((280, 280)),
    transforms.ToTensor(),
    _spread_first_channel,
    _move_signal_to_last_channel,
])




def get_data(paths, l, u, num_classes=15):
    """Load the images named in ``paths[l:u]`` with their multi-hot labels.

    Relies on the notebook-level ``image_paths`` (files found on disk),
    ``entries`` (metadata table indexed by image name), ``labels`` (label name
    -> column index) and ``trns`` (image -> tensor transform).

    Args:
        paths: sequence of image file names.
        l: start index (inclusive) of the slice of ``paths`` to load.
        u: stop index (exclusive) of the slice of ``paths`` to load.
        num_classes: width of each label vector; defaults to the previously
            hard-coded 15.

    Returns:
        ``(X, Y)`` where ``X`` stacks the transformed images and ``Y`` stacks
        the matching multi-hot label vectors. Names with no file in
        ``image_paths`` are skipped, as before.

    Raises:
        RuntimeError: if none of the requested names could be found on disk.
    """
    import os  # local import keeps the function self-contained

    # Build a base-name -> full-path index once per call, so each lookup is
    # O(1) instead of re-scanning all of image_paths for every requested name.
    # setdefault keeps the first occurrence, as the original first-match scan did.
    path_by_name = {}
    for image_path in image_paths:
        path_by_name.setdefault(os.path.basename(image_path), image_path)

    X_data = []
    data_y = []
    for path in tqdm.tqdm(paths[l: u]):
        image_path = path_by_name.get(os.path.basename(path))
        if image_path is None:
            continue

        # The context manager closes the file handle once the transform has
        # produced its tensor.
        with Image.open(image_path) as p:
            t_x = trns(p)

        y = torch.zeros(num_classes)
        for d in entries.loc[path, 'Finding Labels'].split('|'):
            y[labels[d]] = 1

        X_data.append(t_x)
        data_y.append(y)

    if not X_data:
        raise RuntimeError(
            'get_data: none of paths[%r:%r] matched a file in image_paths' % (l, u)
        )
    return torch.stack(X_data), torch.stack(data_y)


# Model

In [7]:
# Fetch the 'dinov2_vits14' entry point from the facebookresearch/dinov2 hub
# repository (downloads the repo and checkpoint on first use).
dinov2_vits14 = torch.hub.load(
    repo_or_dir='facebookresearch/dinov2',
    model='dinov2_vits14',
)

Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vits14_pretrain.pth
100%|██████████| 84.2M/84.2M [00:01<00:00, 62.0MB/s]


In [8]:
class CLF(torch.nn.Module):
    """An encoder followed by a single linear classification layer.

    Args:
        encoder: module whose output has ``dim`` features in its last axis.
        num_classes: number of outputs produced by the linear head.
        dim: number of input features expected by the linear head.
    """

    def __init__(self, encoder, num_classes=15, dim=384):
        super().__init__()
        self.encoder = encoder
        self.dim = dim
        self.num_classes = num_classes
        self.head = torch.nn.Linear(
            in_features=self.dim,
            out_features=self.num_classes,
        )

    def forward(self, x):
        """Run ``x`` through the encoder, then the head, and return the result."""
        features = self.encoder(x)
        logits = self.head(features)
        return logits

In [9]:
def train(model, x, y, loss_fn, optimizer):
    """Perform one optimisation step on a single batch.

    Puts ``model`` in training mode, clears existing gradients, computes
    ``loss_fn(model(x), y)``, back-propagates and steps ``optimizer``.

    Returns:
        The batch loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    batch_loss = loss_fn(model(x), y)
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()

In [10]:
def eval(model, x, y):
    """Return, per output column, the fraction of rows predicted correctly.

    ``model`` is switched to evaluation mode and run without gradient
    tracking. An output greater than 0 counts as a positive prediction and is
    compared element-wise with ``y``; the matches are averaged over axis 0.
    """
    model.eval()
    with torch.no_grad():
        logits = model(x)
        matches = (logits > 0).eq(y)
        return matches.to(torch.float32).mean(dim=0)

In [None]:
# Fine-tune the classifier, streaming the dataset through memory in chunks of
# `chunk_size` images, each of which is split into mini-batches of `batch_size`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
clf = CLF(dinov2_vits14).to(device)
opt = torch.optim.Adam(clf.parameters(), lr=1e-3)
loss_fn = torch.nn.BCEWithLogitsLoss()
batch_size = 16
chunk_size = 2500
num_epochs = 1000

for epoch in range(num_epochs):
    e_loss = 0
    accs = []      # per-batch accuracy vectors, each weighted by its batch size
    n_tested = 0   # number of test samples behind `accs`

    # Step by chunk_size so that chunks tile train_paths without overlapping.
    # (Stepping by 1 made consecutive chunks share all but one image.)
    for start in range(0, len(train_paths), chunk_size):
        stop = min(start + chunk_size, len(train_paths))
        print('Getting data')
        X_train, train_y = get_data(train_paths, start, stop)
        print('Training')
        x_batches = torch.split(X_train, batch_size)
        y_batches = torch.split(train_y, batch_size)
        for x, y in tqdm.tqdm(zip(x_batches, y_batches), total=len(x_batches)):
            e_loss += train(clf, x.to(device), y.to(device), loss_fn, opt)

    for start in range(0, len(test_paths), chunk_size):
        stop = min(start + chunk_size, len(test_paths))
        print('Getting data')
        X_test, test_y = get_data(test_paths, start, stop)
        print('Testing')
        x_batches = torch.split(X_test, batch_size)
        y_batches = torch.split(test_y, batch_size)
        for x, y in tqdm.tqdm(zip(x_batches, y_batches), total=len(x_batches)):
            acc = eval(clf, x.to(device), y.to(device))
            # Weight by batch size: the last batch of a chunk is usually
            # smaller and must not count as much as a full one.
            accs.append(acc * len(x))
            n_tested += len(x)

    # `epoch` is a dedicated name, so the inner loops no longer clobber it.
    print('Epoch', epoch)
    print('Loss: ', e_loss)
    if n_tested:
        print(torch.stack(accs).sum(0) / n_tested)

    print('------------')

Getting data


100%|██████████| 2500/2500 [01:25<00:00, 29.41it/s]


Training


157it [00:33,  4.69it/s]


Getting data


100%|██████████| 2500/2500 [01:23<00:00, 29.80it/s]


Training


157it [00:33,  4.68it/s]


Getting data


100%|██████████| 2500/2500 [01:24<00:00, 29.60it/s]


Training


157it [00:33,  4.68it/s]


Getting data


100%|██████████| 2500/2500 [01:24<00:00, 29.68it/s]


Training


157it [00:33,  4.68it/s]


Getting data


100%|██████████| 2500/2500 [01:23<00:00, 29.78it/s]


Training


157it [00:33,  4.68it/s]


Getting data


100%|██████████| 2500/2500 [01:23<00:00, 29.81it/s]


Training


157it [00:33,  4.68it/s]


Getting data


100%|██████████| 2500/2500 [01:24<00:00, 29.58it/s]


Training


157it [00:33,  4.68it/s]


Getting data


100%|██████████| 2500/2500 [01:24<00:00, 29.67it/s]


Training


157it [00:33,  4.68it/s]


Getting data


100%|██████████| 2500/2500 [01:23<00:00, 29.93it/s]


Training


111it [00:23,  4.69it/s]